1use triton_vm::prelude::*;
2
3use crate::list::get::Get;
4use crate::list::length::Length;
5use crate::prelude::*;
6
7#[derive(Debug, Clone, Eq, PartialEq, Hash)]
28pub struct Set {
29 element_type: DataType,
30}
31
32impl Set {
33 pub const INDEX_OUT_OF_BOUNDS_ERROR_ID: i128 = 390;
34
35 pub const MEM_PAGE_ACCESS_VIOLATION_ERROR_ID: i128 = 391;
38
39 pub fn new(element_type: DataType) -> Self {
44 Get::assert_element_type_is_supported(&element_type);
45
46 Self { element_type }
47 }
48}
49
50impl BasicSnippet for Set {
51 fn parameters(&self) -> Vec<(DataType, String)> {
52 let element_type = self.element_type.clone();
53 let list_type = DataType::List(Box::new(element_type.clone()));
54 let index_type = DataType::U32;
55
56 vec![
57 (element_type, "element".to_string()),
58 (list_type, "*list".to_string()),
59 (index_type, "index".to_string()),
60 ]
61 }
62
63 fn return_values(&self) -> Vec<(DataType, String)> {
64 vec![]
65 }
66
67 fn entrypoint(&self) -> String {
68 let element_type = self.element_type.label_friendly_name();
69 format!("tasmlib_list_set_element___{element_type}")
70 }
71
72 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
73 let list_length = library.import(Box::new(Length));
74 let mul_with_element_size = match self.element_type.stack_size() {
75 1 => triton_asm!(), n => triton_asm!(push {n} mul),
77 };
78 let add_element_size_minus_1 = match self.element_type.stack_size() {
79 1 => triton_asm!(), n => triton_asm!(addi {n - 1}),
81 };
82
83 triton_asm!(
84 {self.entrypoint()}:
87 dup 1
89 call {list_length} dup 1
91 lt assert error_id {Self::INDEX_OUT_OF_BOUNDS_ERROR_ID}
93 {&mul_with_element_size}
96 addi 1 dup 0
101 {&add_element_size_minus_1}
102 split
104 pop 1
105 push 0
106 eq
107 assert error_id {Self::MEM_PAGE_ACCESS_VIOLATION_ERROR_ID}
108 add {&self.element_type.write_value_to_memory_pop_pointer()}
112 return
114 )
115 }
116}
117
#[cfg(test)]
mod tests {
    use proptest::collection::vec;

    use super::*;
    use crate::U32_TO_USIZE_ERR;
    use crate::rust_shadowing_helper_functions::list::insert_random_list;
    use crate::rust_shadowing_helper_functions::list::list_set;
    use crate::test_prelude::*;

    impl Set {
        /// Set up memory holding a random list of `list_length` elements at
        /// `list_pointer`, and a stack holding the snippet's arguments
        /// `[element] *list index`.
        fn set_up_initial_state(
            &self,
            list_length: usize,
            index: usize,
            list_pointer: BFieldElement,
            element: Vec<BFieldElement>,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            insert_random_list(&self.element_type, list_pointer, list_length, &mut memory);

            let mut stack = self.init_stack_for_isolated_run();
            // Reversed so that `element[0]` is the element word closest to the
            // top; `rust_shadow` pops the words back in their original order.
            stack.extend(element.into_iter().rev());
            stack.push(list_pointer);
            stack.push(bfe!(index));

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for Set {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let index = pop_encodable::<u32>(stack)?;
            let list_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
            let element = (0..self.element_type.stack_size())
                .map(|_| stack.pop().ok_or(RustShadowError::StackUnderflow))
                .try_collect()?;

            let index = index.try_into().expect(U32_TO_USIZE_ERR);
            list_set(list_pointer, index, element, memory)?;

            Ok(())
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let (list_length, index, list_pointer) = Get::random_len_idx_ptr(bench_case, &mut rng);
            let element = self.element_type.seeded_random_element(&mut rng);

            self.set_up_initial_state(list_length, index, list_pointer, element)
        }
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        for ty in [
            DataType::Bool,
            DataType::Bfe,
            DataType::U32,
            DataType::U64,
            DataType::Xfe,
            DataType::Digest,
        ] {
            ShadowedFunction::new(Set::new(ty)).test();
        }
    }

    #[macro_rules_attr::apply(proptest)]
    fn out_of_bounds_access_crashes_vm(
        #[strategy(0_usize..=1_000)] list_length: usize,
        #[strategy(#list_length..=u32::MAX.try_into().unwrap())] index: usize,
        #[strategy(arb())] list_pointer: BFieldElement,
        #[strategy(vec(arb(), 1))] element: Vec<BFieldElement>,
    ) {
        let set = Set::new(DataType::Bfe);
        let initial_state = set.set_up_initial_state(list_length, index, list_pointer, element);
        test_assertion_failure(
            &ShadowedFunction::new(set),
            initial_state.into(),
            &[Set::INDEX_OUT_OF_BOUNDS_ERROR_ID],
        );
    }

    #[test_strategy::proptest]
    // `usize` is only 32 bits wide on wasm32, so indices of 2^32 or more
    // cannot be expressed there.
    #[cfg(not(target_arch = "wasm32"))]
    fn too_large_indices_crash_vm(
        // Upper bound: `bfe!` reduces values >= `BFieldElement::P` modulo P,
        // which turns them into valid `u32`s. The VM would then fail the
        // bounds check instead of the `u32` conversion this test expects.
        #[strategy(1_usize << 32..BFieldElement::P as usize)] index: usize,
        #[strategy(arb())] list_pointer: BFieldElement,
        #[strategy(vec(arb(), 1))] element: Vec<BFieldElement>,
    ) {
        use triton_vm::error::OpStackError::FailedU32Conversion;

        let list_length = 0;
        let set = Set::new(DataType::Bfe);
        let initial_state = set.set_up_initial_state(list_length, index, list_pointer, element);
        let expected_error = InstructionError::OpStackError(FailedU32Conversion(bfe!(index)));
        crate::test_helpers::negative_test(
            &ShadowedFunction::new(set),
            initial_state.into(),
            &[expected_error],
        );
    }

    #[macro_rules_attr::apply(proptest(cases = 100))]
    fn too_large_lists_crash_vm(
        #[strategy(1_u64 << 22..1 << 32)] list_length: u64,
        #[strategy((1 << 22) - 1..#list_length)] index: u64,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        // Only the list's length word is written to memory. The bounds check
        // passes because `index < list_length`.
        let mut memory = HashMap::default();
        memory.insert(list_pointer, bfe!(list_length));

        // With 2^10-word elements and `index >= 2^22 - 1`, the accessed
        // element's last word lies at offset `(index + 1) * 2^10 >= 2^32`.
        let tuple_ty = DataType::Tuple(vec![DataType::Bfe; 1 << 10]);
        let set = Set::new(tuple_ty);

        // The element itself is deliberately not pushed. The memory page
        // assertion fails before the element would be consumed.
        let mut stack = set.init_stack_for_isolated_run();
        stack.push(list_pointer);
        stack.push(bfe!(index));
        let initial_state = AccessorInitialState { stack, memory };

        test_assertion_failure(
            &ShadowedFunction::new(set),
            initial_state.into(),
            &[Set::MEM_PAGE_ACCESS_VIOLATION_ERROR_ID],
        );
    }
}
258
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark writing a `Digest`-typed element into a list.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let snippet = Set::new(DataType::Digest);
        ShadowedFunction::new(snippet).bench();
    }
}