1use triton_vm::prelude::*;
2
3use crate::list::get::Get;
4use crate::list::length::Length;
5use crate::prelude::*;
6
7#[derive(Debug, Clone, Eq, PartialEq, Hash)]
28pub struct Set {
29 element_type: DataType,
30}
31
32impl Set {
33 pub const INDEX_OUT_OF_BOUNDS_ERROR_ID: i128 = 390;
34
35 pub const MEM_PAGE_ACCESS_VIOLATION_ERROR_ID: i128 = 391;
38
39 pub fn new(element_type: DataType) -> Self {
44 Get::assert_element_type_is_supported(&element_type);
45
46 Self { element_type }
47 }
48}
49
50impl BasicSnippet for Set {
51 fn inputs(&self) -> Vec<(DataType, String)> {
52 let element_type = self.element_type.clone();
53 let list_type = DataType::List(Box::new(element_type.clone()));
54 let index_type = DataType::U32;
55
56 vec![
57 (element_type, "element".to_string()),
58 (list_type, "*list".to_string()),
59 (index_type, "index".to_string()),
60 ]
61 }
62
63 fn outputs(&self) -> Vec<(DataType, String)> {
64 vec![]
65 }
66
67 fn entrypoint(&self) -> String {
68 let element_type = self.element_type.label_friendly_name();
69 format!("tasmlib_list_set_element___{element_type}")
70 }
71
72 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
73 let list_length = library.import(Box::new(Length));
74 let mul_with_element_size = match self.element_type.stack_size() {
75 1 => triton_asm!(), n => triton_asm!(push {n} mul),
77 };
78 let add_element_size_minus_1 = match self.element_type.stack_size() {
79 1 => triton_asm!(), n => triton_asm!(addi {n - 1}),
81 };
82
83 triton_asm!(
84 {self.entrypoint()}:
87 dup 1
89 call {list_length} dup 1
91 lt assert error_id {Self::INDEX_OUT_OF_BOUNDS_ERROR_ID}
93 {&mul_with_element_size}
96 addi 1 dup 0
101 {&add_element_size_minus_1}
102 split
104 pop 1
105 push 0
106 eq
107 assert error_id {Self::MEM_PAGE_ACCESS_VIOLATION_ERROR_ID}
108 add {&self.element_type.write_value_to_memory_pop_pointer()}
112 return
114 )
115 }
116}
117
#[cfg(test)]
mod tests {
    use proptest::collection::vec;
    use triton_vm::error::OpStackError::FailedU32Conversion;

    use super::*;
    use crate::U32_TO_USIZE_ERR;
    use crate::rust_shadowing_helper_functions::list::insert_random_list;
    use crate::rust_shadowing_helper_functions::list::list_set;
    use crate::test_helpers::negative_test;
    use crate::test_prelude::*;

    impl Set {
        /// Builds an initial state whose memory holds a random list of
        /// `list_length` elements at `list_pointer`, and whose stack ends
        /// with the element (first word topmost), the list pointer, and the
        /// index.
        fn set_up_initial_state(
            &self,
            list_length: usize,
            index: usize,
            list_pointer: BFieldElement,
            element: Vec<BFieldElement>,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            insert_random_list(&self.element_type, list_pointer, list_length, &mut memory);

            let stack = self
                .init_stack_for_isolated_run()
                .into_iter()
                .chain(element.into_iter().rev())
                .chain([list_pointer, bfe!(index)])
                .collect();

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for Set {
        /// Reference implementation: pops index, list pointer, and element
        /// off the stack, then delegates the write to `list_set`.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let index = pop_encodable::<u32>(stack);
            let list_pointer = stack.pop().unwrap();

            let num_words = self.element_type.stack_size();
            let mut element = Vec::with_capacity(num_words);
            for _ in 0..num_words {
                element.push(stack.pop().unwrap());
            }

            list_set(
                list_pointer,
                index.try_into().expect(U32_TO_USIZE_ERR),
                element,
                memory,
            );
        }

        /// Derives list length, index, list pointer, and the element to
        /// write from the seed.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            // The draw order below determines the generated state; keep it.
            let (len, idx, ptr) = Get::random_len_idx_ptr(bench_case, &mut rng);
            let new_element = self.element_type.seeded_random_element(&mut rng);

            self.set_up_initial_state(len, idx, ptr, new_element)
        }
    }

    /// Compares the snippet against its Rust shadow for several element types.
    #[test]
    fn rust_shadow() {
        let element_types = [
            DataType::Bool,
            DataType::Bfe,
            DataType::U32,
            DataType::U64,
            DataType::Xfe,
            DataType::Digest,
        ];

        element_types
            .into_iter()
            .map(Set::new)
            .map(ShadowedFunction::new)
            .for_each(|shadowed| shadowed.test());
    }

    /// An index of at least the list's length, but still below 2^32, must
    /// trip the bounds assertion.
    #[proptest]
    fn out_of_bounds_access_crashes_vm(
        #[strategy(0_usize..=1_000)] list_length: usize,
        #[strategy(#list_length..1 << 32)] index: usize,
        #[strategy(arb())] list_pointer: BFieldElement,
        #[strategy(vec(arb(), 1))] element: Vec<BFieldElement>,
    ) {
        let snippet = Set::new(DataType::Bfe);
        let state = snippet.set_up_initial_state(list_length, index, list_pointer, element);

        test_assertion_failure(
            &ShadowedFunction::new(snippet),
            state.into(),
            &[Set::INDEX_OUT_OF_BOUNDS_ERROR_ID],
        );
    }

    /// An index of 2^32 or more is expected to fail with a u32 conversion
    /// error rather than an assertion failure.
    #[proptest]
    fn too_large_indices_crash_vm(
        #[strategy(1_usize << 32..)] index: usize,
        #[strategy(arb())] list_pointer: BFieldElement,
        #[strategy(vec(arb(), 1))] element: Vec<BFieldElement>,
    ) {
        let snippet = Set::new(DataType::Bfe);
        // The list is empty.
        let state = snippet.set_up_initial_state(0, index, list_pointer, element);
        let conversion_error = InstructionError::OpStackError(FailedU32Conversion(bfe!(index)));

        negative_test(
            &ShadowedFunction::new(snippet),
            state.into(),
            &[conversion_error],
        );
    }

    /// With elements of 2^10 words and an in-bounds index of at least
    /// 2^22 - 1, the offset of the element's last word is
    /// `index * 2^10 + 2^10 >= 2^32`, which must trip the 32-bit offset
    /// assertion.
    #[proptest(cases = 100)]
    fn too_large_lists_crash_vm(
        #[strategy(1_u64 << 22..1 << 32)] list_length: u64,
        #[strategy((1 << 22) - 1..#list_length)] index: u64,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        // Only the length word is stored; no list elements are materialized.
        let mut memory = HashMap::default();
        memory.insert(list_pointer, bfe!(list_length));

        // A tuple of 2^10 single-word fields serves as the large element type.
        let large_element_type = DataType::Tuple(vec![DataType::Bfe; 1 << 10]);
        let snippet = Set::new(large_element_type);

        // No element is placed on the stack; only the pointer and the index.
        let mut stack = snippet.init_stack_for_isolated_run();
        stack.extend([list_pointer, bfe!(index)]);
        let state = AccessorInitialState { stack, memory };

        test_assertion_failure(
            &ShadowedFunction::new(snippet),
            state.into(),
            &[Set::MEM_PAGE_ACCESS_VIOLATION_ERROR_ID],
        );
    }
}
254
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet for lists of `Digest` elements.
    #[test]
    fn benchmark() {
        let snippet = Set::new(DataType::Digest);
        ShadowedFunction::new(snippet).bench();
    }
}