tasm_lib/list/
set.rs

1use triton_vm::prelude::*;
2
3use crate::list::get::Get;
4use crate::list::length::Length;
5use crate::prelude::*;
6
7/// Write an element to a list. Performs bounds check.
8///
9/// Only supports lists with [statically sized](BFieldCodec::static_length)
10/// elements.
11///
12/// ### Behavior
13///
14/// ```text
15/// BEFORE: _ [element: ElementType] *list [index: u32]
16/// AFTER:  _
17/// ```
18///
19/// ### Preconditions
20///
21/// - the argument `*list` points to a properly [`BFieldCodec`]-encoded list
22/// - all input arguments are properly [`BFieldCodec`] encoded
23///
24/// ### Postconditions
25///
26/// None.
27#[derive(Debug, Clone, Eq, PartialEq, Hash)]
28pub struct Set {
29    element_type: DataType,
30}
31
32impl Set {
33    pub const INDEX_OUT_OF_BOUNDS_ERROR_ID: i128 = 390;
34
35    /// Any part of the list is outside the allocated memory page.
36    /// See the [memory convention][crate::memory] for more details.
37    pub const MEM_PAGE_ACCESS_VIOLATION_ERROR_ID: i128 = 391;
38
39    /// # Panics
40    ///
41    /// Panics if the element has [dynamic length][BFieldCodec::static_length], or
42    /// if the static length is 0.
43    pub fn new(element_type: DataType) -> Self {
44        Get::assert_element_type_is_supported(&element_type);
45
46        Self { element_type }
47    }
48}
49
50impl BasicSnippet for Set {
51    fn inputs(&self) -> Vec<(DataType, String)> {
52        let element_type = self.element_type.clone();
53        let list_type = DataType::List(Box::new(element_type.clone()));
54        let index_type = DataType::U32;
55
56        vec![
57            (element_type, "element".to_string()),
58            (list_type, "*list".to_string()),
59            (index_type, "index".to_string()),
60        ]
61    }
62
63    fn outputs(&self) -> Vec<(DataType, String)> {
64        vec![]
65    }
66
67    fn entrypoint(&self) -> String {
68        let element_type = self.element_type.label_friendly_name();
69        format!("tasmlib_list_set_element___{element_type}")
70    }
71
72    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
73        let list_length = library.import(Box::new(Length));
74        let mul_with_element_size = match self.element_type.stack_size() {
75            1 => triton_asm!(), // no-op
76            n => triton_asm!(push {n} mul),
77        };
78        let add_element_size_minus_1 = match self.element_type.stack_size() {
79            1 => triton_asm!(), // no-op
80            n => triton_asm!(addi {n - 1}),
81        };
82
83        triton_asm!(
84            // BEFORE: _ [element: self.element_type] *list index
85            // AFTER:  _
86            {self.entrypoint()}:
87                /* assert access is in bounds */
88                dup 1
89                call {list_length}  // _ [element] *list index len
90                dup 1
91                lt                  // _ [element] *list index (index < len)
92                assert error_id {Self::INDEX_OUT_OF_BOUNDS_ERROR_ID}
93                                    // _ [element] *list index
94
95                {&mul_with_element_size}
96                                    // _ [element] *list offset_for_previous_elements
97                addi 1              // _ [element] *list offset
98
99                /* assert access is within one memory page */
100                dup 0
101                {&add_element_size_minus_1}
102                                    // _ [element] *list offset highest_word_idx
103                split
104                pop 1
105                push 0
106                eq
107                assert error_id {Self::MEM_PAGE_ACCESS_VIOLATION_ERROR_ID}
108                                    // _ [element] *list offset_including_list_metadata
109
110                add                 // _ [element] *element
111                {&self.element_type.write_value_to_memory_pop_pointer()}
112                                    // _
113                return
114        )
115    }
116}
117
#[cfg(test)]
mod tests {
    use proptest::collection::vec;
    use triton_vm::error::OpStackError::FailedU32Conversion;

    use super::*;
    use crate::U32_TO_USIZE_ERR;
    use crate::rust_shadowing_helper_functions::list::insert_random_list;
    use crate::rust_shadowing_helper_functions::list::list_set;
    use crate::test_helpers::negative_test;
    use crate::test_prelude::*;

    impl Set {
        /// Build an initial state: a list of `list_length` elements (created via
        /// `insert_random_list`) at `list_pointer`, and the stack
        /// `_ [element] *list index`.
        fn set_up_initial_state(
            &self,
            list_length: usize,
            index: usize,
            list_pointer: BFieldElement,
            element: Vec<BFieldElement>,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            insert_random_list(&self.element_type, list_pointer, list_length, &mut memory);

            let mut stack = self.init_stack_for_isolated_run();
            // push in reverse so that `element[0]` is the element's topmost word
            stack.extend(element.into_iter().rev());
            stack.push(list_pointer);
            stack.push(bfe!(index));

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for Set {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let index = pop_encodable::<u32>(stack);
            let list_pointer = stack.pop().unwrap();
            let element = (0..self.element_type.stack_size())
                .map(|_| stack.pop().unwrap())
                .collect_vec();

            let index = index.try_into().expect(U32_TO_USIZE_ERR);
            list_set(list_pointer, index, element, memory);
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let (list_length, index, list_pointer) = Get::random_len_idx_ptr(bench_case, &mut rng);
            let element = self.element_type.seeded_random_element(&mut rng);

            self.set_up_initial_state(list_length, index, list_pointer, element)
        }
    }

    #[test]
    fn rust_shadow() {
        for ty in [
            DataType::Bool,
            DataType::Bfe,
            DataType::U32,
            DataType::U64,
            DataType::Xfe,
            DataType::Digest,
        ] {
            ShadowedFunction::new(Set::new(ty)).test();
        }
    }

    #[proptest]
    fn out_of_bounds_access_crashes_vm(
        #[strategy(0_usize..=1_000)] list_length: usize,
        #[strategy(#list_length..1 << 32)] index: usize,
        #[strategy(arb())] list_pointer: BFieldElement,
        #[strategy(vec(arb(), 1))] element: Vec<BFieldElement>,
    ) {
        let set = Set::new(DataType::Bfe);
        let initial_state = set.set_up_initial_state(list_length, index, list_pointer, element);
        test_assertion_failure(
            &ShadowedFunction::new(set),
            initial_state.into(),
            &[Set::INDEX_OUT_OF_BOUNDS_ERROR_ID],
        );
    }

    /// Indices that are not valid `u32`s make instruction `lt` fail.
    ///
    /// Indices are drawn strictly below the field modulus: larger values wrap
    /// around when converted to a [`BFieldElement`] and can become valid `u32`s,
    /// which would produce a different error than the one expected here.
    #[proptest]
    fn too_large_indices_crash_vm(
        #[strategy((1_usize << 32)..(BFieldElement::P as usize))] index: usize,
        #[strategy(arb())] list_pointer: BFieldElement,
        #[strategy(vec(arb(), 1))] element: Vec<BFieldElement>,
    ) {
        let list_length = 0;
        let set = Set::new(DataType::Bfe);
        let initial_state = set.set_up_initial_state(list_length, index, list_pointer, element);
        let expected_error = InstructionError::OpStackError(FailedU32Conversion(bfe!(index)));
        negative_test(
            &ShadowedFunction::new(set),
            initial_state.into(),
            &[expected_error],
        );
    }

    /// See mirroring test for [`Get`] for an explanation.
    #[proptest(cases = 100)]
    fn too_large_lists_crash_vm(
        #[strategy(1_u64 << 22..1 << 32)] list_length: u64,
        #[strategy((1 << 22) - 1..#list_length)] index: u64,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        // spare host machine RAM: pretend every element is all-zeros
        let mut memory = HashMap::default();
        memory.insert(list_pointer, bfe!(list_length));

        // type with a large stack size in Triton VM without breaking the host machine
        let tuple_ty = DataType::Tuple(vec![DataType::Bfe; 1 << 10]);
        let set = Set::new(tuple_ty);

        // no element on stack: stack underflow implies things have gone wrong already
        let mut stack = set.init_stack_for_isolated_run();
        stack.push(list_pointer);
        stack.push(bfe!(index));

        // `Set` is a `Function`, so use the matching initial-state type
        let initial_state = FunctionInitialState { stack, memory };

        test_assertion_failure(
            &ShadowedFunction::new(set),
            initial_state.into(),
            &[Set::MEM_PAGE_ACCESS_VIOLATION_ERROR_ID],
        );
    }
}
254
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark writing a [`DataType::Digest`] element into a list.
    #[test]
    fn benchmark() {
        let snippet = Set::new(DataType::Digest);
        ShadowedFunction::new(snippet).bench();
    }
}
264}