tasm_lib/rust_shadowing_helper_functions/
list.rs

1use std::collections::HashMap;
2
3use num::Zero;
4use num_traits::ConstOne;
5use triton_vm::prelude::*;
6use twenty_first::math::other::random_elements;
7
8use crate::U32_TO_USIZE_ERR;
9use crate::USIZE_TO_U64_ERR;
10use crate::list::LIST_METADATA_SIZE;
11use crate::memory::dyn_malloc::DYN_MALLOC_PAGE_SIZE;
12use crate::prelude::*;
13
14/// Load a list from memory returning each element as a list of `BFieldElement`s.
15pub fn load_list_unstructured(
16    element_size: usize,
17    list_pointer: BFieldElement,
18    memory: &HashMap<BFieldElement, BFieldElement>,
19) -> Vec<Vec<BFieldElement>> {
20    let list_length: usize = memory[&list_pointer].value().try_into().unwrap();
21
22    let mut element_pointer = list_pointer + BFieldElement::new(LIST_METADATA_SIZE as u64);
23
24    let mut ret = Vec::with_capacity(list_length);
25    for i in 0..list_length {
26        ret.push(vec![]);
27        for _ in 0..element_size {
28            ret[i].push(memory[&element_pointer]);
29            element_pointer.increment();
30        }
31    }
32
33    ret
34}
35
36/// Load a list from memory. Elements must be of `Copy` type.
37pub fn load_list_with_copy_elements<const ELEMENT_SIZE: usize>(
38    list_pointer: BFieldElement,
39    memory: &HashMap<BFieldElement, BFieldElement>,
40) -> Vec<[BFieldElement; ELEMENT_SIZE]> {
41    let list_length: usize = memory[&list_pointer].value().try_into().unwrap();
42
43    let mut element_pointer = list_pointer + BFieldElement::new(LIST_METADATA_SIZE as u64);
44
45    let mut ret = Vec::with_capacity(list_length);
46    for i in 0..list_length {
47        ret.push([BFieldElement::zero(); ELEMENT_SIZE]);
48        for j in 0..ELEMENT_SIZE {
49            ret[i][j] = memory[&element_pointer];
50            element_pointer.increment();
51        }
52    }
53
54    ret
55}
56
57pub fn list_insert<T: BFieldCodec>(
58    list_pointer: BFieldElement,
59    vector: Vec<T>,
60    memory: &mut HashMap<BFieldElement, BFieldElement>,
61) {
62    list_new(list_pointer, memory);
63
64    for element in vector {
65        list_push(list_pointer, element.encode(), memory);
66    }
67}
68
69pub fn insert_random_list(
70    element_type: &DataType,
71    list_pointer: BFieldElement,
72    list_length: usize,
73    memory: &mut HashMap<BFieldElement, BFieldElement>,
74) {
75    let list = element_type.random_list(&mut rand::rng(), list_length);
76    let indexed_list = list
77        .into_iter()
78        .enumerate()
79        .map(|(i, v)| (list_pointer + bfe!(i), v));
80    memory.extend(indexed_list);
81}
82
83// TODO: Get rid of this stupid "helper" function
84pub fn untyped_insert_random_list(
85    list_pointer: BFieldElement,
86    list_length: usize,
87    memory: &mut HashMap<BFieldElement, BFieldElement>,
88    element_length: usize,
89) {
90    list_new(list_pointer, memory);
91    for _ in 0..list_length {
92        let random_element: Vec<BFieldElement> = random_elements(element_length);
93        list_push(list_pointer, random_element, memory);
94    }
95}
96
97pub fn list_new(list_pointer: BFieldElement, memory: &mut HashMap<BFieldElement, BFieldElement>) {
98    memory.insert(list_pointer, BFieldElement::zero());
99}
100
101/// Push the given element to the pointed-to list.
102///
103/// Only supports lists with statically sized elements.
104///
105/// # Panics
106///
107/// Panics if the pointed-to list is incorrectly encoded.
108pub fn list_push(
109    list_pointer: BFieldElement,
110    value: Vec<BFieldElement>,
111    memory: &mut HashMap<BFieldElement, BFieldElement>,
112) {
113    let list_length = memory
114        .get_mut(&list_pointer)
115        .expect("list must be initialized");
116    let len = list_length.value();
117    list_length.increment();
118
119    let element_size: u64 = value.len().try_into().expect(USIZE_TO_U64_ERR);
120    let list_metadata_size: u64 = LIST_METADATA_SIZE.try_into().expect(USIZE_TO_U64_ERR);
121    let highest_access_index = list_metadata_size + element_size * (len + 1);
122    assert!(highest_access_index < DYN_MALLOC_PAGE_SIZE);
123
124    for (i, word) in (0..).zip(value) {
125        let word_offset = bfe!(list_metadata_size + element_size * len + i);
126        memory.insert(list_pointer + word_offset, word);
127    }
128}
129
130/// Pop an element from the pointed-to list.
131///
132/// Only supports lists with statically sized elements.
133///
134/// # Panics
135///
136/// Panics if the pointed-to list is empty, or if the list is incorrectly
137/// encoded.
138pub fn list_pop(
139    list_pointer: BFieldElement,
140    memory: &mut HashMap<BFieldElement, BFieldElement>,
141    element_length: usize,
142) -> Vec<BFieldElement> {
143    let list_length = memory
144        .get_mut(&list_pointer)
145        .expect("list must be initialized");
146    assert_ne!(0, list_length.value(), "list must not be empty");
147    list_length.decrement();
148    let last_item_index = list_length.value();
149
150    let element_length: u64 = element_length.try_into().expect(USIZE_TO_U64_ERR);
151    let read_word = |i| {
152        let word_offset = bfe!(LIST_METADATA_SIZE) + bfe!(element_length * last_item_index + i);
153        let word_index = list_pointer + bfe!(word_offset);
154        memory[&word_index]
155    };
156
157    (0..element_length).map(read_word).collect()
158}
159
160/// A pointer to the `i`th element in the list, as well as the size of that
161/// element.
162///
163/// Supports both, lists with statically _and_ lists with dynamically sized
164/// elements.
165///
166/// # Panics
167///
168/// Panics if the `index` is out of bounds, or if the pointed-to-list is
169/// incorrectly encoded.
170pub fn list_pointer_to_elem_pointer(
171    list_pointer: BFieldElement,
172    index: usize,
173    memory: &HashMap<BFieldElement, BFieldElement>,
174    element_type: &DataType,
175) -> (usize, BFieldElement) {
176    let list_len = list_get_length(list_pointer, memory);
177    assert!(index < list_len, "out of bounds: {index} >= {list_len}");
178
179    if let Some(element_size) = element_type.static_length() {
180        let elem_ptr = list_pointer + bfe!(LIST_METADATA_SIZE + index * element_size);
181        return (element_size, elem_ptr);
182    }
183
184    let mut elem_pointer = list_pointer + bfe!(LIST_METADATA_SIZE);
185    for _ in 0..index {
186        elem_pointer += memory[&elem_pointer] + BFieldElement::ONE;
187    }
188    let elem_size = usize::try_from(memory[&elem_pointer].value()).unwrap();
189    (elem_size, elem_pointer + BFieldElement::ONE)
190}
191
192/// Read an element from a list.
193///
194/// Only supports lists with statically sized elements.
195///
196/// # Panics
197///
198/// Panics if
199/// - the `index` is out of bounds, or
200/// - the element that is to be read resides outside the list`s
201///   [memory page][crate::memory], or
202/// - the pointed-to-list is incorrectly encoded into `memory`.
203pub fn list_get(
204    list_pointer: BFieldElement,
205    index: usize,
206    memory: &HashMap<BFieldElement, BFieldElement>,
207    element_size: usize,
208) -> Vec<BFieldElement> {
209    let list_len = list_get_length(list_pointer, memory);
210    assert!(index < list_len, "out of bounds: {index} >= {list_len}");
211
212    let highest_access_index = LIST_METADATA_SIZE + element_size * (index + 1);
213    assert!(u64::try_from(highest_access_index).expect(USIZE_TO_U64_ERR) < DYN_MALLOC_PAGE_SIZE);
214
215    let read_word = |i| {
216        let word_offset = LIST_METADATA_SIZE + element_size * index + i;
217        let word_index = list_pointer + bfe!(word_offset);
218        memory[&word_index]
219    };
220
221    (0..element_size).map(read_word).collect()
222}
223
224/// Write an element to a list.
225///
226/// Only supports lists with statically sized elements.
227///
228/// # Panics
229///
230/// Panics if
231/// - the `index` is out of bounds, or
232/// - the element that is to be read resides outside the list`s
233///   [memory page][crate::memory], or
234/// - the pointed-to-list is incorrectly encoded into `memory`.
235pub fn list_set(
236    list_pointer: BFieldElement,
237    index: usize,
238    value: Vec<BFieldElement>,
239    memory: &mut HashMap<BFieldElement, BFieldElement>,
240) {
241    let list_len = list_get_length(list_pointer, memory);
242    assert!(index < list_len, "out of bounds: {index} >= {list_len}");
243
244    let element_size = value.len();
245    let highest_access_index = LIST_METADATA_SIZE + element_size * (index + 1);
246    assert!(u64::try_from(highest_access_index).expect(USIZE_TO_U64_ERR) < DYN_MALLOC_PAGE_SIZE);
247
248    for (i, word) in value.into_iter().enumerate() {
249        let word_offset = LIST_METADATA_SIZE + element_size * index + i;
250        let word_index = list_pointer + bfe!(word_offset);
251        memory.insert(word_index, word);
252    }
253}
254
255pub fn list_get_length(
256    list_pointer: BFieldElement,
257    memory: &HashMap<BFieldElement, BFieldElement>,
258) -> usize {
259    let length: u32 = memory[&list_pointer].value().try_into().unwrap();
260
261    length.try_into().expect(U32_TO_USIZE_ERR)
262}
263
264pub fn list_set_length(
265    list_pointer: BFieldElement,
266    new_length: usize,
267    memory: &mut HashMap<BFieldElement, BFieldElement>,
268) {
269    memory.insert(list_pointer, bfe!(new_length));
270}
271
#[cfg(test)]
mod tests {
    use proptest::prop_assert_eq;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use super::*;

    #[test]
    fn new_list_set_length() {
        let mut memory = HashMap::default();
        let list_pointer = BFieldElement::new(20);
        list_new(list_pointer, &mut memory);
        assert!(list_get_length(list_pointer, &memory).is_zero());
        let new_length = 51;
        list_set_length(list_pointer, new_length, &mut memory);
        assert_eq!(new_length, list_get_length(list_pointer, &memory));
    }

    /// Exercises the statically-sized-element helpers against each other:
    /// push, load, get, set, and pop must agree on the memory layout.
    #[test]
    fn push_get_set_pop_round_trip() {
        let mut memory = HashMap::default();
        let list_pointer = BFieldElement::new(100);
        list_new(list_pointer, &mut memory);

        let words = [[1, 2, 3], [4, 5, 6]].map(|element| element.map(BFieldElement::new));
        let elements = words.map(|element| element.to_vec());
        for element in &elements {
            list_push(list_pointer, element.clone(), &mut memory);
        }

        assert_eq!(elements.len(), list_get_length(list_pointer, &memory));
        assert_eq!(
            words.to_vec(),
            load_list_with_copy_elements::<3>(list_pointer, &memory)
        );
        assert_eq!(
            elements.to_vec(),
            load_list_unstructured(3, list_pointer, &memory)
        );
        for (i, element) in elements.iter().enumerate() {
            assert_eq!(*element, list_get(list_pointer, i, &memory, 3));
        }

        let replacement = [7, 8, 9].map(BFieldElement::new).to_vec();
        list_set(list_pointer, 0, replacement.clone(), &mut memory);
        assert_eq!(replacement, list_get(list_pointer, 0, &memory, 3));
        assert_eq!(elements[1], list_get(list_pointer, 1, &memory, 3));

        assert_eq!(elements[1], list_pop(list_pointer, &mut memory, 3));
        assert_eq!(replacement, list_pop(list_pointer, &mut memory, 3));
        assert!(list_get_length(list_pointer, &memory).is_zero());
    }

    #[proptest]
    fn element_pointer_from_list_pointer_on_static_list_with_static_length_items(
        #[strategy(arb())] list: Vec<Digest>,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        let indexed_list = list
            .encode()
            .into_iter()
            .enumerate()
            .map(|(i, v)| (list_pointer + bfe!(i), v));

        let mut memory = HashMap::default();
        memory.extend(indexed_list);

        let data_type = DataType::Digest;
        for (i, digest) in list.into_iter().enumerate() {
            let (len, ptr) = list_pointer_to_elem_pointer(list_pointer, i, &memory, &data_type);
            prop_assert_eq!(Digest::LEN, len);
            prop_assert_eq!(digest.values()[0], memory[&ptr]);
        }
    }

    #[proptest]
    fn element_pointer_from_list_pointer_on_static_list_with_dyn_length_items(
        #[strategy(arb())] list: Vec<Vec<BFieldElement>>,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        let indexed_list = list
            .encode()
            .into_iter()
            .enumerate()
            .map(|(i, v)| (list_pointer + bfe!(i), v));

        let mut memory = HashMap::default();
        memory.extend(indexed_list);

        let data_type = DataType::List(Box::new(DataType::Bfe));
        for (i, inner_list) in list.into_iter().enumerate() {
            let (len, ptr) = list_pointer_to_elem_pointer(list_pointer, i, &memory, &data_type);
            prop_assert_eq!(inner_list.encode().len(), len);
            prop_assert_eq!(bfe!(inner_list.len()), memory[&ptr]);
        }
    }
}