Skip to main content

tasm_lib/hashing/algebraic_hasher/
sample_scalars_static_length_static_pointer.rs

1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3use twenty_first::tip5::RATE;
4
5use crate::hashing::algebraic_hasher::sample_scalars_static_length_dyn_malloc::SampleScalarsStaticLengthDynMalloc;
6use crate::hashing::squeeze_repeatedly_static_number::SqueezeRepeatedlyStaticNumber;
7use crate::prelude::*;
8
9/// Squeeze the sponge to sample a given number of [`XFieldElement`]s. Puts the scalars into
10/// statically allocated memory.
11///
12/// # Panics
13///
14/// Panics if both fields are 0 because the static allocator will be unhappy. :)
15#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
16pub struct SampleScalarsStaticLengthStaticPointer {
17    pub num_elements_to_sample: usize,
18
19    /// Number of additional elements to statically allocate, in number of
20    /// [extension field element][xfe]s.
21    /// Necessary for [`Challenges`][chall].
22    ///
23    /// [chall]: crate::verifier::challenges::new_empty_input_and_output::NewEmptyInputAndOutput
24    /// [xfe]: XFieldElement
25    pub extra_capacity: usize,
26
27    /// Memory address to store the scalars
28    pub scalars_pointer: BFieldElement,
29}
30
31impl BasicSnippet for SampleScalarsStaticLengthStaticPointer {
32    fn parameters(&self) -> Vec<(DataType, String)> {
33        vec![]
34    }
35
36    fn return_values(&self) -> Vec<(DataType, String)> {
37        vec![]
38    }
39
40    fn entrypoint(&self) -> String {
41        format!(
42            "tasmlib_hashing_algebraic_hasher_sample_scalars_static_length_static_pointer_{}_{}",
43            self.num_elements_to_sample, self.scalars_pointer
44        )
45    }
46
47    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
48        assert_eq!(10, RATE, "Code assumes Tip5's RATE is 10");
49        assert_eq!(3, EXTENSION_DEGREE, "Code assumes extension degree 3");
50        let num_squeezes =
51            SampleScalarsStaticLengthDynMalloc::num_squeezes(self.num_elements_to_sample);
52
53        let num_squeezed_words = num_squeezes * RATE;
54        debug_assert!(
55            self.num_elements_to_sample * EXTENSION_DEGREE <= num_squeezed_words,
56            "need {} elements but getting {num_squeezed_words}",
57            self.num_elements_to_sample * EXTENSION_DEGREE,
58        );
59
60        let entrypoint = self.entrypoint();
61        let squeeze_repeatedly_static_number =
62            library.import(Box::new(SqueezeRepeatedlyStaticNumber { num_squeezes }));
63
64        triton_asm!(
65            {entrypoint}:
66                push {self.scalars_pointer}
67                call {squeeze_repeatedly_static_number}
68                return
69        )
70    }
71}
72
#[cfg(test)]
pub(crate) mod tests {
    use twenty_first::util_types::sponge::Sponge;

    use super::*;
    use crate::prelude::Tip5;
    use crate::rust_shadowing_helper_functions::array::array_get;
    use crate::rust_shadowing_helper_functions::array::insert_as_array;
    use crate::test_helpers::tasm_final_state;
    use crate::test_prelude::*;

    impl Procedure for SampleScalarsStaticLengthStaticPointer {
        /// Squeezes the sponge as many times as the TASM code does and writes every squeezed
        /// word, in order, to memory starting at `scalars_pointer`. The stack is untouched.
        ///
        /// Returns [`RustShadowError::SpongeUninitialized`] if no sponge is present.
        fn rust_shadow(
            &self,
            _: &mut Vec<BFieldElement>,
            memory: &mut std::collections::HashMap<BFieldElement, BFieldElement>,
            _: &NonDeterminism,
            _: &[BFieldElement],
            sponge: &mut Option<Tip5>,
        ) -> Result<Vec<BFieldElement>, RustShadowError> {
            let Some(sponge) = sponge.as_mut() else {
                return Err(RustShadowError::SpongeUninitialized);
            };
            let num_squeezes =
                SampleScalarsStaticLengthDynMalloc::num_squeezes(self.num_elements_to_sample);

            // `squeeze` returns an array, which can be flattened directly instead of being
            // copied into a temporary `Vec` for every squeeze.
            let pseudorandomness = (0..num_squeezes)
                .flat_map(|_| sponge.squeeze())
                .collect_vec();
            insert_as_array(self.scalars_pointer, memory, pseudorandomness);

            Ok(Vec::new())
        }

        /// Random initial sponge state; the stack is the one for an isolated run.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _bench_case: Option<BenchmarkCase>,
        ) -> ProcedureInitialState {
            let mut rng = StdRng::from_seed(seed);
            let stack = self.init_stack_for_isolated_run();
            let sponge = Tip5 {
                state: rng.random(),
            };

            ProcedureInitialState {
                stack,
                sponge: Some(sponge),
                ..Default::default()
            }
        }

        /// Corner case: a freshly initialized sponge.
        fn corner_case_initial_states(&self) -> Vec<ProcedureInitialState> {
            let freshly_initialized_sponge = ProcedureInitialState {
                stack: self.init_stack_for_isolated_run(),
                sponge: Some(Tip5::init()),
                ..Default::default()
            };

            vec![freshly_initialized_sponge]
        }
    }

    #[macro_rules_attr::apply(test)]
    fn sample_scalars_static_length_pbt() {
        for num_elements_to_sample in 0..11 {
            for extra_capacity in 0..11 {
                // The all-zero configuration is excluded. NOTE(review): confirm whether it is
                // actually unsupported by this snippet.
                if num_elements_to_sample + extra_capacity == 0 {
                    continue;
                }

                // Only draw a random pointer for configurations that are actually tested.
                let scalars_pointer: BFieldElement = rand::random();
                ShadowedProcedure::new(SampleScalarsStaticLengthStaticPointer {
                    num_elements_to_sample,
                    extra_capacity,
                    scalars_pointer,
                })
                .test();
            }
        }
    }

    /// The scalars written by the TASM code must agree with [`Tip5::sample_scalars`] applied
    /// to the same initial sponge.
    #[macro_rules_attr::apply(proptest)]
    fn verify_agreement_with_tip5_sample_scalars(
        #[strategy(0_usize..500)] num_elements_to_sample: usize,
        #[strategy(0_usize..500)] extra_capacity: usize,
        #[strategy(arb())] scalars_pointer: BFieldElement,
        #[strategy(arb())] mut sponge: Tip5,
    ) {
        let snippet = SampleScalarsStaticLengthStaticPointer {
            num_elements_to_sample,
            extra_capacity,
            scalars_pointer,
        };
        let init_stack = snippet.init_stack_for_isolated_run();
        let tasm = tasm_final_state(
            &ShadowedProcedure::new(snippet),
            &init_stack,
            &[],
            NonDeterminism::default(),
            &Some(sponge.clone()),
        )
        .unwrap();

        // Scalars are laid out contiguously, `EXTENSION_DEGREE` words each.
        let read_scalar = |i| array_get(scalars_pointer, i, &tasm.ram, EXTENSION_DEGREE);

        let scalars_from_tip5 = sponge.sample_scalars(num_elements_to_sample);
        for (i, expected_scalar) in scalars_from_tip5.into_iter().enumerate() {
            assert_eq!(expected_scalar.coefficients.to_vec(), read_scalar(i));
        }
    }
}
185
#[cfg(test)]
mod bench {
    use super::*;
    use crate::test_prelude::*;
    use crate::verifier::challenges::shared::conventional_challenges_pointer;

    /// Benchmarks sampling `num_elements_to_sample` scalars into the conventional challenges
    /// location, with `extra_capacity` set to 4.
    fn bench_sampling(num_elements_to_sample: usize) {
        let snippet = SampleScalarsStaticLengthStaticPointer {
            num_elements_to_sample,
            extra_capacity: 4,
            scalars_pointer: conventional_challenges_pointer(),
        };
        ShadowedProcedure::new(snippet).bench();
    }

    #[macro_rules_attr::apply(test)]
    fn bench_10() {
        bench_sampling(10);
    }

    #[macro_rules_attr::apply(test)]
    fn bench_100() {
        bench_sampling(100);
    }

    #[macro_rules_attr::apply(test)]
    fn bench_63() {
        bench_sampling(63);
    }
}