Skip to main content

tasm_lib/hashing/algebraic_hasher/sample_scalars_static_length_dyn_malloc.rs

1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3use twenty_first::tip5::RATE;
4
5use crate::data_type::ArrayType;
6use crate::hashing::squeeze_repeatedly_static_number::SqueezeRepeatedlyStaticNumber;
7use crate::prelude::*;
8
/// Snippet that samples a fixed number of `XFieldElement`s by squeezing the
/// sponge, storing the result in memory obtained from the dynamic allocator.
///
/// The number of scalars is a compile-time (static) property of the snippet,
/// so each distinct length is a distinct snippet.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SampleScalarsStaticLengthDynMalloc {
    /// How many `XFieldElement`s the snippet samples.
    pub num_elements: usize,
}
14
15impl SampleScalarsStaticLengthDynMalloc {
16    pub(super) fn num_squeezes(num_elements: usize) -> usize {
17        (num_elements * EXTENSION_DEGREE).div_ceil(RATE)
18    }
19}
20
21impl BasicSnippet for SampleScalarsStaticLengthDynMalloc {
22    fn parameters(&self) -> Vec<(DataType, String)> {
23        vec![]
24    }
25
26    fn return_values(&self) -> Vec<(DataType, String)> {
27        vec![(
28            DataType::Array(Box::new(ArrayType {
29                element_type: DataType::Xfe,
30                length: self.num_elements,
31            })),
32            "*scalars".to_owned(),
33        )]
34    }
35
36    fn entrypoint(&self) -> String {
37        format!(
38            "tasmlib_hashing_algebraic_hasher_sample_scalars_static_length_dyn_malloc_{}",
39            self.num_elements
40        )
41    }
42
43    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
44        assert_eq!(10, RATE, "Code assumes Tip5's RATE is 10");
45        assert_eq!(3, EXTENSION_DEGREE, "Code assumes extension degree 3");
46        let num_squeezes = Self::num_squeezes(self.num_elements);
47
48        debug_assert!(
49            self.num_elements * EXTENSION_DEGREE <= num_squeezes * RATE,
50            "need {} elements but getting {}",
51            self.num_elements * EXTENSION_DEGREE,
52            num_squeezes * RATE
53        );
54
55        let entrypoint = self.entrypoint();
56        let squeeze_repeatedly_static_number =
57            library.import(Box::new(SqueezeRepeatedlyStaticNumber { num_squeezes }));
58        let dyn_malloc = library.import(Box::new(DynMalloc));
59
60        triton_asm!(
61            {entrypoint}:
62                // _
63
64                // Allocate memory for return-array
65                call {dyn_malloc}
66                // _ *array
67
68                dup 0
69                // _ *array *array
70
71                // squeeze
72                call {squeeze_repeatedly_static_number}
73                // _ *array
74
75                return
76        )
77    }
78}
79
#[cfg(test)]
mod tests {
    use twenty_first::prelude::*;

    use super::*;
    use crate::memory::dyn_malloc::DYN_MALLOC_FIRST_ADDRESS;
    use crate::rust_shadowing_helper_functions;
    use crate::test_helpers::tasm_final_state;
    use crate::test_prelude::*;

    impl Procedure for SampleScalarsStaticLengthDynMalloc {
        /// Rust reference implementation of the snippet.
        ///
        /// It pushes the scalars pointer, squeezes the sponge `num_squeezes`
        /// times, and stores *all* squeezed words in memory starting at that
        /// pointer, not just the `num_elements * EXTENSION_DEGREE` words that
        /// form the scalars.
        ///
        /// # Errors
        ///
        /// Returns `RustShadowError::SpongeUninitialized` if no sponge is
        /// provided.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
            _nondeterminism: &NonDeterminism,
            _public_input: &[BFieldElement],
            sponge: &mut Option<Tip5>,
        ) -> Result<Vec<BFieldElement>, RustShadowError> {
            let Some(sponge) = sponge.as_mut() else {
                return Err(RustShadowError::SpongeUninitialized);
            };

            // In an isolated run the snippet's `dyn_malloc` call is expected to
            // hand out this first address.
            let scalars_pointer = DYN_MALLOC_FIRST_ADDRESS;
            stack.push(scalars_pointer);

            // Write the squeeze output straight to memory; no intermediate
            // `Vec` is needed.
            let num_squeezes = Self::num_squeezes(self.num_elements);
            let pseudorandomness = (0..num_squeezes).flat_map(|_| sponge.squeeze());
            for (offset, word) in pseudorandomness.enumerate() {
                memory.insert(BFieldElement::new(offset as u64) + scalars_pointer, word);
            }

            Ok(Vec::new())
        }

        /// Random initial sponge state (seeded), with the stack prepared for
        /// an isolated run.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _: Option<BenchmarkCase>,
        ) -> ProcedureInitialState {
            let mut rng = StdRng::from_seed(seed);
            let stack = self.init_stack_for_isolated_run();
            let sponge = Tip5 {
                state: rng.random(),
            };

            ProcedureInitialState {
                stack,
                sponge: Some(sponge),
                ..Default::default()
            }
        }
    }

    #[macro_rules_attr::apply(test)]
    fn sample_scalars_static_length_pbt() {
        for num_elements in 0..100 {
            ShadowedProcedure::new(SampleScalarsStaticLengthDynMalloc { num_elements }).test();
        }
    }

    /// The scalars written by the TASM snippet must match those returned by
    /// `Tip5::sample_scalars` on the same starting sponge.
    #[macro_rules_attr::apply(test)]
    fn verify_agreement_with_tip5_sample_scalars() {
        let empty_sponge = Tip5::init();
        let mut non_empty_sponge = Tip5::init();
        non_empty_sponge.absorb([BFieldElement::new(100); Tip5::RATE]);

        for init_sponge in [empty_sponge, non_empty_sponge] {
            for num_elements in 0..30 {
                let snippet = SampleScalarsStaticLengthDynMalloc { num_elements };
                let init_stack = snippet.init_stack_for_isolated_run();
                let final_state = tasm_final_state(
                    &ShadowedProcedure::new(snippet),
                    &init_stack,
                    &[],
                    NonDeterminism::default(),
                    &Some(init_sponge.clone()),
                )
                .unwrap();

                // The snippet leaves the scalars pointer on top of the stack.
                let scalars_pointer = *final_state
                    .op_stack
                    .stack
                    .last()
                    .expect("snippet must leave the scalars pointer on the stack");

                let expected_scalars =
                    Tip5::sample_scalars(&mut init_sponge.clone(), num_elements);
                for (i, expected_scalar) in expected_scalars.into_iter().enumerate() {
                    let actual = rust_shadowing_helper_functions::array::array_get(
                        scalars_pointer,
                        i,
                        &final_state.ram,
                        EXTENSION_DEGREE,
                    );
                    assert_eq!(expected_scalar.coefficients.to_vec(), actual);
                }
            }
        }
    }
}
184
#[cfg(test)]
mod bench {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark for a snippet sampling `num_elements` scalars.
    fn bench_with(num_elements: usize) {
        ShadowedProcedure::new(SampleScalarsStaticLengthDynMalloc { num_elements }).bench();
    }

    #[macro_rules_attr::apply(test)]
    fn bench_10() {
        bench_with(10);
    }

    #[macro_rules_attr::apply(test)]
    fn bench_63() {
        bench_with(63);
    }

    #[macro_rules_attr::apply(test)]
    fn bench_100() {
        bench_with(100);
    }
}
204}