Skip to main content

tasm_lib/hashing/algebraic_hasher/
sample_indices.rs

1use triton_vm::prelude::*;
2
3use crate::list::length::Length;
4use crate::list::new::New;
5use crate::list::push::Push;
6use crate::prelude::*;
7
/// Sample `number` pseudorandom indices from `[0, upper_bound)` by squeezing the sponge.
///
/// Stack effect: `_ number upper_bound` → `_ *indices`, where `*indices` points to a freshly
/// allocated list of `u32`. It is the caller's responsibility to ensure that the sponge is
/// initialized to the right state.
///
/// The snippet squeezes the sponge 10 elements at a time and consumes the squeezed elements
/// in order until `number` indices have been collected. An element equal to `p - 1` is
/// skipped. Every other element is reduced to an index by taking its low 32 bits and
/// masking them with `upper_bound - 1`. Squeezed elements left over once the list is full
/// are discarded.
///
/// **Precondition: `upper_bound` must be a power of two.** The mask `AND (upper_bound - 1)`
/// equals `word mod upper_bound` only when `upper_bound` is a power of two. The snippet does
/// not check this precondition. Every in-tree caller passes a power-of-two bound (e.g. the
/// FRI domain length or the mutator-set window size), so the caller is responsible for
/// upholding it. If it is violated:
///
/// - For a non-zero `upper_bound` that is not a power of two, the snippet does not crash.
///   Every index is still smaller than `upper_bound`, but only values whose set bits are a
///   subset of the set bits of `upper_bound - 1` can occur. For example, `upper_bound = 6`
///   masks with `0b101` and only ever yields `0`, `1`, `4`, or `5`. The result is biased,
///   not a uniform sample over `[0, upper_bound)`.
/// - For `upper_bound = 0`, `upper_bound - 1` is computed in the field and equals `p - 1`.
///   That is not a valid `u32` operand, so the `and` instruction fails and execution
///   crashes.
///
/// The reduction mirrors `Tip5::sample_indices`, which the tests use as the reference
/// implementation and which carries the same power-of-two requirement.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SampleIndices;
23
impl BasicSnippet for SampleIndices {
    /// Inputs, from deepest to top of stack: the number of indices to sample, then the
    /// exclusive upper bound (which must be a power of two).
    fn parameters(&self) -> Vec<(DataType, String)> {
        vec![
            (DataType::U32, "number".to_string()),
            (DataType::U32, "upper_bound".to_string()),
        ]
    }

    /// Output: a pointer to a list of `u32` holding the sampled indices.
    fn return_values(&self) -> Vec<(DataType, String)> {
        vec![(
            DataType::List(Box::new(DataType::U32)),
            "*indices".to_string(),
        )]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_hashing_algebraic_hasher_sample_indices".into()
    }

    /// Emits the snippet's assembly.
    ///
    /// - The entrypoint allocates an empty list and replaces `upper_bound` with the mask
    ///   `upper_bound - 1`. It then runs the main loop and finally leaves only the list
    ///   pointer.
    /// - The main loop returns as soon as the list holds `number` elements. Otherwise it
    ///   squeezes the sponge (10 elements), runs one unrolled step per squeezed element,
    ///   and recurses.
    /// - Each step calls exactly one of two subroutines. `then_reduce_and_save` masks the
    ///   element and pushes it onto the list. `else_drop_tip` discards the element. The
    ///   choice depends on whether the list still has room and the element differs from
    ///   `p - 1`.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        let main_loop = format!("{entrypoint}_main_loop");
        let then_reduce_and_save = format!("{entrypoint}_then_reduce_and_save");
        let else_drop_tip = format!("{entrypoint}_else_drop_tip");

        let new_list = library.import(Box::new(New));
        let length = library.import(Box::new(Length));
        let push_element = library.import(Box::new(Push::new(DataType::U32)));

        // Decides whether the squeezed element `prn` (just below the copied
        // `number upper_bound-1 *indices` triple) may be used. It may be used iff the list is
        // not yet full AND `prn != p - 1` (`push -1`); elements equal to `p - 1` are never
        // used. Leaves `~can_use can_use` with `can_use` on top. The first `skiz` at the call
        // site consumes `can_use`; the second `skiz` consumes `~can_use` (or the `0` that
        // `then_reduce_and_save` leaves in its place). Exactly one branch therefore runs.
        let if_can_sample = triton_asm! (
            // BEFORE: _ prn number upper_bound-1 *indices
            // AFTER:  _ prn number upper_bound-1 *indices ~can_use can_use
            dup 0 call {length}         // _ prn number upper_bound-1 *indices length
            dup 3 eq                    // _ prn number upper_bound-1 *indices length==number
            push 0 eq                   // _ prn number upper_bound-1 *indices length!=number
            dup 4 push -1 eq            // _ prn number upper_bound-1 *indices length!=number prn==max
            push 0 eq                   // _ prn number upper_bound-1 *indices length!=number prn!=max
            mul                         // _ prn number upper_bound-1 *indices length!=number&&prn!=max
            dup 0                       // _ prn number upper_bound-1 *indices can_use can_use
            push 0 eq                   // _ prn number upper_bound-1 *indices can_use ~can_use
            swap 1                      // _ prn number upper_bound-1 *indices ~can_use can_use
        );

        triton_asm! (
            // BEFORE: _ number upper_bound
            // AFTER:  _ *indices
            {entrypoint}:
                call {new_list}         // _ number upper_bound *indices

                // Replace `upper_bound` by the mask `upper_bound-1`, then run the main loop.
                // The subtraction happens in the field, so `upper_bound = 0` yields `p - 1`,
                // which is not a u32 and later makes `and` fail.
                swap 1                  // _ number *indices upper_bound
                push -1 add             // _ number *indices upper_bound-1
                swap 1                  // _ number upper_bound-1 *indices
                call {main_loop}        // _ number upper_bound-1 *indices

                // drop `number` and the mask, keep only the list pointer
                swap 2 pop 2            // _ *indices
                return

            // INVARIANT: _ number upper_bound-1 *indices
            {main_loop}:
                // stop once `number` indices have been collected
                dup 0 call {length}     // _ number upper_bound-1 *indices length
                dup 3 eq                // _ number upper_bound-1 *indices length==number
                skiz return             // _ number upper_bound-1 *indices

                // not done yet: squeeze 10 fresh pseudorandom elements
                sponge_squeeze          // _ number upper_bound-1 *indices [prn; 10]

                // Copy the invariant triple above the 10 squeezed elements, so that each
                // unrolled step below sees `_ prn number upper_bound-1 *indices`.
                dup 12 dup 12 dup 12    // _ number upper_bound-1 *indices [prn; 10] number upper_bound-1 *indices

                // One step per squeezed element. Whichever branch runs, that step's `prn` is
                // removed from below the copied triple, exposing the next one.
                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
                {&if_can_sample} skiz call {then_reduce_and_save} skiz call {else_drop_tip}
                                        // _ number upper_bound-1 *indices number upper_bound-1 *indices

                // drop the copies to restore the invariant, then re-check termination
                pop 3                   // _ number upper_bound-1 *indices
                recurse

            // Reduce `prn` into range and append it to the list. Only the low 32 bits of
            // `prn` are kept (`split`), then masked with `upper_bound-1`. The trailing
            // `push 0` makes the caller's following `skiz` skip the `else_drop_tip` call.
            // BEFORE: _ prn number upper_bound-1 *indices 0
            // AFTER:  _ number upper_bound-1 *indices 0
            {then_reduce_and_save}:
                pop 1                   // _ prn number upper_bound-1 *indices
                swap 2 swap 3           // _ number *indices upper_bound-1 prn
                split                   // _ number *indices upper_bound-1 hi lo
                dup 2 and               // _ number *indices upper_bound-1 hi index
                swap 1 pop 1            // _ number *indices upper_bound-1 index

                swap 1 swap 2 swap 1    // _ number upper_bound-1 *indices index
                dup 1 swap 1            // _ number upper_bound-1 *indices *indices index
                call {push_element}     // _ number upper_bound-1 *indices

                push 0
                return

            // Discard `prn` without using it.
            // BEFORE: _ prn number upper_bound-1 *indices
            // AFTER:  _ number upper_bound-1 *indices
            {else_drop_tip}:
                swap 2 swap 3           // _ number *indices upper_bound-1 prn
                pop 1 swap 1            // _ number upper_bound-1 *indices
                return

        )
    }
}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::empty_stack;
    use crate::rust_shadowing_helper_functions;
    use crate::test_prelude::*;

    impl Procedure for SampleIndices {
        /// Rust reference for the snippet.
        ///
        /// The method:
        /// 1. pops `upper_bound` (top of stack) and then `number`;
        /// 2. samples `number` indices with `Tip5::sample_indices`;
        /// 3. stores them in a freshly allocated list;
        /// 4. pushes the list pointer.
        ///
        /// # Errors
        ///
        /// Returns `RustShadowError::SpongeUninitialized` if no sponge state is present and
        /// `RustShadowError::StackUnderflow` if either argument is missing. Errors from
        /// `list_push` are propagated.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
            _: &NonDeterminism,
            _: &[BFieldElement],
            sponge: &mut Option<Tip5>,
        ) -> Result<Vec<BFieldElement>, RustShadowError> {
            let Some(sponge) = sponge.as_mut() else {
                return Err(RustShadowError::SpongeUninitialized);
            };

            // collect upper bound and number from stack
            let upper_bound = stack.pop().ok_or(RustShadowError::StackUnderflow)?.value() as u32;
            let number = stack.pop().ok_or(RustShadowError::StackUnderflow)?.value() as usize;

            println!("sampling {number} indices between 0 and {upper_bound}");
            println!("sponge before: {}", sponge.state.iter().join(","));

            let indices = sponge.sample_indices(upper_bound, number);

            // allocate memory for list
            let list_pointer =
                rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
            rust_shadowing_helper_functions::list::list_new(list_pointer, memory);

            // store all indices
            for index in indices.iter() {
                rust_shadowing_helper_functions::list::list_push(
                    list_pointer,
                    vec![BFieldElement::new(*index as u64)],
                    memory,
                )?;
            }
            println!("sponge after: {}", sponge.state.iter().join(","));

            stack.push(list_pointer);

            Ok(Vec::new())
        }

        /// Builds an initial state with `number` and a power-of-two `upper_bound` on the
        /// stack, plus a pseudorandom sponge state.
        ///
        /// Benchmark cases use fixed parameters. Otherwise both parameters are drawn from
        /// `rng`, with `number` drawn before the `upper_bound` exponent.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> ProcedureInitialState {
            let mut rng = StdRng::from_seed(seed);

            // `upper_bound` is always a power of two, as the snippet requires.
            // The tuple is evaluated left to right, so `number` consumes randomness first.
            let (number, upper_bound) = match bench_case {
                // For FRI num_collinearity checks is 40 for expansion factor 8
                Some(BenchmarkCase::CommonCase) => (40, 1 << 12),

                // For FRI num_collinearity checks is 80 for expansion factor 4
                Some(BenchmarkCase::WorstCase) => (80, 1 << 23),

                None => (rng.random_range(0..20), 1 << rng.random_range(0..20)),
            };

            let mut stack = empty_stack();
            stack.push(BFieldElement::new(number as u64));
            stack.push(BFieldElement::new(upper_bound as u64));

            let public_input: Vec<BFieldElement> = vec![];
            let state = Tip5 {
                state: rng.random(),
            };

            ProcedureInitialState {
                stack,
                nondeterminism: NonDeterminism::default(),
                public_input,
                sponge: Some(state),
            }
        }
    }

    // Compares the snippet against its Rust shadow on pseudorandom initial states.
    #[macro_rules_attr::apply(test)]
    fn test() {
        ShadowedProcedure::new(SampleIndices).test();
    }
}
237
#[cfg(test)]
mod bench {
    use super::*;
    use crate::test_prelude::*;

    // Runs the benchmark harness on the snippet, wrapped in its Rust shadow.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let snippet = SampleIndices::default();
        ShadowedProcedure::new(snippet).bench();
    }
}