Skip to main content

tasm_lib/hashing/algebraic_hasher/
sample_indices.rs

1use triton_vm::prelude::*;
2
3use crate::list::length::Length;
4use crate::list::new::New;
5use crate::list::push::Push;
6use crate::prelude::*;
7
/// Sample `number` pseudorandom indices in the range `0..upper_bound` by squeezing the sponge.
///
/// Each squeezed element equal to `-1` (the largest field element) is rejected. Every other
/// element is reduced by taking its low 32-bit limb and masking it with `upper_bound - 1`. That
/// mask coincides with reduction modulo `upper_bound` only when `upper_bound` is a power of two,
/// so callers must pass a power of two. The sponge is squeezed as many times as needed to collect
/// `number` indices; elements of the last squeeze that are not needed are discarded.
///
/// It is the caller's responsibility to ensure that the sponge is initialized to the right state.
///
/// Stack effect: `_ number upper_bound` → `_ *indices`, where `*indices` points to a new list of
/// `u32`s.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SampleIndices;
12
13impl BasicSnippet for SampleIndices {
14    fn parameters(&self) -> Vec<(DataType, String)> {
15        vec![
16            (DataType::U32, "number".to_string()),
17            (DataType::U32, "upper_bound".to_string()),
18        ]
19    }
20
21    fn return_values(&self) -> Vec<(DataType, String)> {
22        vec![(
23            DataType::List(Box::new(DataType::U32)),
24            "*indices".to_string(),
25        )]
26    }
27
28    fn entrypoint(&self) -> String {
29        "tasmlib_hashing_algebraic_hasher_sample_indices".into()
30    }
31
32    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
33        let entrypoint = self.entrypoint();
34        let main_loop = format!("{entrypoint}_main_loop");
35        let then_reduce_and_save = format!("{entrypoint}_then_reduce_and_save");
36        let else_drop_tip = format!("{entrypoint}_else_drop_tip");
37
38        let new_list = library.import(Box::new(New));
39        let length = library.import(Box::new(Length));
40        let push_element = library.import(Box::new(Push::new(DataType::U32)));
41
42        let if_can_sample = triton_asm! (
43            // BEFORE: _ prn number upper_bound *indices
44            // AFTER:  _ prn number upper_bound *indices ~can_use can_use
45            dup 0 call {length}         // _ prn number upper_bound *indices length
46            dup 3 eq                    // _ prn number upper_bound *indices length==number
47            push 0 eq                   // _ prn number upper_bound *indices length!=number
48            dup 4 push -1 eq            // _ prn number upper_bound *indices length!=number prn==max
49            push 0 eq                   // _ prn number upper_bound *indices length!=number prn!=max
50            mul                         // _ prn number upper_bound *indices length!=number&&prn!=max
51            dup 0                       // _ prn number upper_bound *indices length!=number&&prn!=max length!=number&&prn!=max
52            push 0 eq                   // _ prn number upper_bound *indices length!=number&&prn!=max ~(length!=number&&prn!=max)
53            swap 1                      // _ prn number upper_bound *indices ~(length!=number&&prn!=max) length!=number&&prn!=max
54        );
55
56        triton_asm! (
57            // BEFORE: _ number upper_bound
58            // AFTER:  _ *indices
59            {entrypoint}:
60                call {new_list}         // _ number upper_bound *indices
61
62                // prepare and call main while lop
63                swap 1                  // _ number *indices upper_bound
64                push -1 add             // _ number *indices upper_bound-1
65                swap 1                  // _ number upper_bound-1 *indices
66                call {main_loop}        // _ number upper_bound-1 *indices
67
68                // clean up and return
69                swap 2 pop 2
70                return
71
72            // INVARIANT: _ number upper_bound-1 *indices
73            {main_loop}:
74                // evaluate termination condition
75                dup 0 call {length}     // _ number upper_bound-1 *indices length
76                dup 3 eq                // _ number upper_bound-1 *indices length==number
77                skiz return             // _ number upper_bound-1 *indices
78
79                // we need to squeeze so squeeze
80                sponge_squeeze          // _ number upper_bound-1 *indices [prn]
81
82                // reject or reduce-and-store
83                dup 12 dup 12 dup 12    // _ number upper_bound-1 *indices [prn] number upper_bound-1 *indices
84
85                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
86                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
87                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
88                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
89                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
90                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
91                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
92                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
93                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
94                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
95                                        // _ number upper_bound-1 *indices number upper_bound-1 *indices
96
97                // return to invariant and repeat
98                pop 3                   // _ number upper_bound-1 *indices
99                recurse
100
101            // BEFORE: _ prn number upper_bound-1 *indices 0
102            // AFTER:  _ number upper_bound-1 *indices 0
103            {then_reduce_and_save}:
104                pop 1                   // _ prn number upper_bound-1 *indices
105                swap 2 swap 3           // _ number *indices upper_bound-1 prn
106                split                   // _ number *indices upper_bound-1 hi lo
107                dup 2 and               // _ number *indices upper_bound-1 hi index
108                swap 1 pop 1            // _ number *indices upper_bound-1 index
109
110                swap 1 swap 2 swap 1    // _ number upper_bound-1 *indices index
111                dup 1 swap 1            // _ number upper_bound-1 *indices *indices index
112                call {push_element}
113
114                push 0
115                return
116
117            // BEFORE: _ prn number upper_bound-1 *indices
118            // AFTER:  _ number upper_bound-1 *indices
119            {else_drop_tip}:
120                swap 2 swap 3           // _ number *indices upper_bound-1 prn
121                pop 1 swap 1            // _ number upper_bound-1 *indices
122                return
123
124        )
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::empty_stack;
    use crate::rust_shadowing_helper_functions;
    use crate::test_prelude::*;

    impl Procedure for SampleIndices {
        /// Reference behavior. Pops `upper_bound` (top of stack) and then `number`, samples with
        /// `Tip5::sample_indices`, stores the indices in a newly allocated list, and pushes the
        /// list pointer.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
            _: &NonDeterminism,
            _: &[BFieldElement],
            sponge: &mut Option<Tip5>,
        ) -> Result<Vec<BFieldElement>, RustShadowError> {
            let Some(sponge) = sponge.as_mut() else {
                return Err(RustShadowError::SpongeUninitialized);
            };

            // arguments arrive as `_ number upper_bound`, so `upper_bound` is popped first
            let upper_bound = stack.pop().ok_or(RustShadowError::StackUnderflow)?.value() as u32;
            let number = stack.pop().ok_or(RustShadowError::StackUnderflow)?.value() as usize;

            let indices = sponge.sample_indices(upper_bound, number);

            // allocate memory for list
            let list_pointer =
                rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
            rust_shadowing_helper_functions::list::list_new(list_pointer, memory);

            // store all indices
            for index in &indices {
                rust_shadowing_helper_functions::list::list_push(
                    list_pointer,
                    vec![BFieldElement::new(*index as u64)],
                    memory,
                )?;
            }

            stack.push(list_pointer);

            Ok(Vec::new())
        }

        /// Benchmarks use fixed FRI-like sizes. Plain tests sample fewer than 20 indices below a
        /// random power of two smaller than `2^20`. The sponge state is always random.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> ProcedureInitialState {
            let mut rng = StdRng::from_seed(seed);

            // RNG draws happen in the order number, upper bound, sponge state (random cases only
            // for the first two), so the generated states are the same for a given seed.
            let (number, upper_bound) = match bench_case {
                // For FRI num_collinearity checks is 40 for expansion factor 8
                Some(BenchmarkCase::CommonCase) => (40, 1 << 12),

                // For FRI num_collinearity checks is 80 for expansion factor 4
                Some(BenchmarkCase::WorstCase) => (80, 1 << 23),

                None => {
                    let number = rng.random_range(0..20);
                    let upper_bound = 1 << rng.random_range(0..20);
                    (number, upper_bound)
                }
            };

            let mut stack = empty_stack();
            stack.push(BFieldElement::new(number as u64));
            stack.push(BFieldElement::new(upper_bound as u64));

            let public_input: Vec<BFieldElement> = vec![];
            let state = Tip5 {
                state: rng.random(),
            };

            ProcedureInitialState {
                stack,
                nondeterminism: NonDeterminism::default(),
                public_input,
                sponge: Some(state),
            }
        }
    }

    // Compares the snippet against its Rust shadow on pseudorandom initial states.
    #[macro_rules_attr::apply(test)]
    fn test() {
        ShadowedProcedure::new(SampleIndices).test();
    }
}
226
#[cfg(test)]
mod bench {
    use super::*;
    use crate::test_prelude::*;

    // Runs the common-case and worst-case benchmarks for the index-sampling snippet.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let snippet = SampleIndices;
        ShadowedProcedure::new(snippet).bench();
    }
}