tasm_lib/rust_shadowing_helper_functions/
list.rs1use std::collections::HashMap;
2
3use num::Zero;
4use num_traits::ConstOne;
5use triton_vm::prelude::*;
6use twenty_first::math::other::random_elements;
7
8use crate::U32_TO_USIZE_ERR;
9use crate::USIZE_TO_U64_ERR;
10use crate::list::LIST_METADATA_SIZE;
11use crate::memory::dyn_malloc::DYN_MALLOC_PAGE_SIZE;
12use crate::prelude::*;
13
14pub fn load_list_unstructured(
16 element_size: usize,
17 list_pointer: BFieldElement,
18 memory: &HashMap<BFieldElement, BFieldElement>,
19) -> Vec<Vec<BFieldElement>> {
20 let list_length: usize = memory[&list_pointer].value().try_into().unwrap();
21
22 let mut element_pointer = list_pointer + BFieldElement::new(LIST_METADATA_SIZE as u64);
23
24 let mut ret = Vec::with_capacity(list_length);
25 for i in 0..list_length {
26 ret.push(vec![]);
27 for _ in 0..element_size {
28 ret[i].push(memory[&element_pointer]);
29 element_pointer.increment();
30 }
31 }
32
33 ret
34}
35
36pub fn load_list_with_copy_elements<const ELEMENT_SIZE: usize>(
38 list_pointer: BFieldElement,
39 memory: &HashMap<BFieldElement, BFieldElement>,
40) -> Vec<[BFieldElement; ELEMENT_SIZE]> {
41 let list_length: usize = memory[&list_pointer].value().try_into().unwrap();
42
43 let mut element_pointer = list_pointer + BFieldElement::new(LIST_METADATA_SIZE as u64);
44
45 let mut ret = Vec::with_capacity(list_length);
46 for i in 0..list_length {
47 ret.push([BFieldElement::zero(); ELEMENT_SIZE]);
48 for j in 0..ELEMENT_SIZE {
49 ret[i][j] = memory[&element_pointer];
50 element_pointer.increment();
51 }
52 }
53
54 ret
55}
56
57pub fn list_insert<T: BFieldCodec>(
58 list_pointer: BFieldElement,
59 vector: Vec<T>,
60 memory: &mut HashMap<BFieldElement, BFieldElement>,
61) {
62 list_new(list_pointer, memory);
63
64 for element in vector {
65 list_push(list_pointer, element.encode(), memory);
66 }
67}
68
69pub fn insert_random_list(
70 element_type: &DataType,
71 list_pointer: BFieldElement,
72 list_length: usize,
73 memory: &mut HashMap<BFieldElement, BFieldElement>,
74) {
75 let list = element_type.random_list(&mut rand::rng(), list_length);
76 let indexed_list = list
77 .into_iter()
78 .enumerate()
79 .map(|(i, v)| (list_pointer + bfe!(i), v));
80 memory.extend(indexed_list);
81}
82
83pub fn untyped_insert_random_list(
85 list_pointer: BFieldElement,
86 list_length: usize,
87 memory: &mut HashMap<BFieldElement, BFieldElement>,
88 element_length: usize,
89) {
90 list_new(list_pointer, memory);
91 for _ in 0..list_length {
92 let random_element: Vec<BFieldElement> = random_elements(element_length);
93 list_push(list_pointer, random_element, memory);
94 }
95}
96
97pub fn list_new(list_pointer: BFieldElement, memory: &mut HashMap<BFieldElement, BFieldElement>) {
98 memory.insert(list_pointer, BFieldElement::zero());
99}
100
101pub fn list_push(
109 list_pointer: BFieldElement,
110 value: Vec<BFieldElement>,
111 memory: &mut HashMap<BFieldElement, BFieldElement>,
112) {
113 let list_length = memory
114 .get_mut(&list_pointer)
115 .expect("list must be initialized");
116 let len = list_length.value();
117 list_length.increment();
118
119 let element_size: u64 = value.len().try_into().expect(USIZE_TO_U64_ERR);
120 let list_metadata_size: u64 = LIST_METADATA_SIZE.try_into().expect(USIZE_TO_U64_ERR);
121 let highest_access_index = list_metadata_size + element_size * (len + 1);
122 assert!(highest_access_index < DYN_MALLOC_PAGE_SIZE);
123
124 for (i, word) in (0..).zip(value) {
125 let word_offset = bfe!(list_metadata_size + element_size * len + i);
126 memory.insert(list_pointer + word_offset, word);
127 }
128}
129
130pub fn list_pop(
139 list_pointer: BFieldElement,
140 memory: &mut HashMap<BFieldElement, BFieldElement>,
141 element_length: usize,
142) -> Vec<BFieldElement> {
143 let list_length = memory
144 .get_mut(&list_pointer)
145 .expect("list must be initialized");
146 assert_ne!(0, list_length.value(), "list must not be empty");
147 list_length.decrement();
148 let last_item_index = list_length.value();
149
150 let element_length: u64 = element_length.try_into().expect(USIZE_TO_U64_ERR);
151 let read_word = |i| {
152 let word_offset = bfe!(LIST_METADATA_SIZE) + bfe!(element_length * last_item_index + i);
153 let word_index = list_pointer + bfe!(word_offset);
154 memory[&word_index]
155 };
156
157 (0..element_length).map(read_word).collect()
158}
159
160pub fn list_pointer_to_elem_pointer(
171 list_pointer: BFieldElement,
172 index: usize,
173 memory: &HashMap<BFieldElement, BFieldElement>,
174 element_type: &DataType,
175) -> (usize, BFieldElement) {
176 let list_len = list_get_length(list_pointer, memory);
177 assert!(index < list_len, "out of bounds: {index} >= {list_len}");
178
179 if let Some(element_size) = element_type.static_length() {
180 let elem_ptr = list_pointer + bfe!(LIST_METADATA_SIZE + index * element_size);
181 return (element_size, elem_ptr);
182 }
183
184 let mut elem_pointer = list_pointer + bfe!(LIST_METADATA_SIZE);
185 for _ in 0..index {
186 elem_pointer += memory[&elem_pointer] + BFieldElement::ONE;
187 }
188 let elem_size = usize::try_from(memory[&elem_pointer].value()).unwrap();
189 (elem_size, elem_pointer + BFieldElement::ONE)
190}
191
192pub fn list_get(
204 list_pointer: BFieldElement,
205 index: usize,
206 memory: &HashMap<BFieldElement, BFieldElement>,
207 element_size: usize,
208) -> Vec<BFieldElement> {
209 let list_len = list_get_length(list_pointer, memory);
210 assert!(index < list_len, "out of bounds: {index} >= {list_len}");
211
212 let highest_access_index = LIST_METADATA_SIZE + element_size * (index + 1);
213 assert!(u64::try_from(highest_access_index).expect(USIZE_TO_U64_ERR) < DYN_MALLOC_PAGE_SIZE);
214
215 let read_word = |i| {
216 let word_offset = LIST_METADATA_SIZE + element_size * index + i;
217 let word_index = list_pointer + bfe!(word_offset);
218 memory[&word_index]
219 };
220
221 (0..element_size).map(read_word).collect()
222}
223
224pub fn list_set(
236 list_pointer: BFieldElement,
237 index: usize,
238 value: Vec<BFieldElement>,
239 memory: &mut HashMap<BFieldElement, BFieldElement>,
240) {
241 let list_len = list_get_length(list_pointer, memory);
242 assert!(index < list_len, "out of bounds: {index} >= {list_len}");
243
244 let element_size = value.len();
245 let highest_access_index = LIST_METADATA_SIZE + element_size * (index + 1);
246 assert!(u64::try_from(highest_access_index).expect(USIZE_TO_U64_ERR) < DYN_MALLOC_PAGE_SIZE);
247
248 for (i, word) in value.into_iter().enumerate() {
249 let word_offset = LIST_METADATA_SIZE + element_size * index + i;
250 let word_index = list_pointer + bfe!(word_offset);
251 memory.insert(word_index, word);
252 }
253}
254
255pub fn list_get_length(
256 list_pointer: BFieldElement,
257 memory: &HashMap<BFieldElement, BFieldElement>,
258) -> usize {
259 let length: u32 = memory[&list_pointer].value().try_into().unwrap();
260
261 length.try_into().expect(U32_TO_USIZE_ERR)
262}
263
264pub fn list_set_length(
265 list_pointer: BFieldElement,
266 new_length: usize,
267 memory: &mut HashMap<BFieldElement, BFieldElement>,
268) {
269 memory.insert(list_pointer, bfe!(new_length));
270}
271
#[cfg(test)]
mod tests {
    use proptest::prop_assert_eq;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use super::*;

    /// A freshly created list has length zero, and a subsequently set length is reported back
    /// by `list_get_length`.
    #[test]
    fn new_list_set_length() {
        let mut memory = HashMap::default();
        let list_pointer = BFieldElement::new(20);
        list_new(list_pointer, &mut memory);
        assert!(list_get_length(list_pointer, &memory).is_zero());
        let new_length = 51;
        list_set_length(list_pointer, new_length, &mut memory);
        assert_eq!(new_length, list_get_length(list_pointer, &memory));
    }

    /// For a list of statically sized elements (digests), the element pointer points to the
    /// first word of the corresponding digest, and the reported size is `Digest::LEN`.
    #[proptest]
    fn element_pointer_from_list_pointer_on_static_list_with_static_length_items(
        #[strategy(arb())] list: Vec<Digest>,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        let indexed_list = list
            .encode()
            .into_iter()
            .enumerate()
            .map(|(i, v)| (list_pointer + bfe!(i), v));

        let mut memory = HashMap::default();
        memory.extend(indexed_list);

        let data_type = DataType::Digest;
        for (i, digest) in list.into_iter().enumerate() {
            let (len, ptr) = list_pointer_to_elem_pointer(list_pointer, i, &memory, &data_type);
            prop_assert_eq!(Digest::LEN, len);
            prop_assert_eq!(digest.values()[0], memory[&ptr]);
        }
    }

    /// For a list of dynamically sized elements (inner lists), the reported size equals the
    /// length of the inner list's encoding, and the element pointer points to the first word of
    /// that encoding, which is the inner list's length.
    #[proptest]
    fn element_pointer_from_list_pointer_on_static_list_with_dyn_length_items(
        #[strategy(arb())] list: Vec<Vec<BFieldElement>>,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        let indexed_list = list
            .encode()
            .into_iter()
            .enumerate()
            .map(|(i, v)| (list_pointer + bfe!(i), v));

        let mut memory = HashMap::default();
        memory.extend(indexed_list);

        let data_type = DataType::List(Box::new(DataType::Bfe));
        for (i, inner_list) in list.into_iter().enumerate() {
            let (len, ptr) = list_pointer_to_elem_pointer(list_pointer, i, &memory, &data_type);
            prop_assert_eq!(inner_list.encode().len(), len);
            prop_assert_eq!(bfe!(inner_list.len()), memory[&ptr]);
        }
    }
}
336}