tasm_lib/list/
split_off.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::list::get::Get;
6use crate::prelude::*;
7use crate::traits::basic_snippet::Reviewer;
8use crate::traits::basic_snippet::SignOffFingerprint;
9
10/// Mutates an existing vector by reducing its length to `at` and returns the
11/// new vector. Mimics [`Vec::split_off`].
12///
13/// Only supports lists with [statically sized](BFieldCodec::static_length)
14/// elements.
15///
16/// ### Behavior
17///
18/// ```text
19/// BEFORE: _ *list [at: u32]
20/// AFTER:  _ *new_list
21/// ```
22///
23/// ### Preconditions
24///
25/// - the argument `*list` points to a properly [`BFieldCodec`]-encoded list
26/// - all input arguments are properly [`BFieldCodec`] encoded
27///
28/// ### Postconditions
29///
30/// - `*new_list` points to a properly [`BFieldCodec`]-encoded list
31#[derive(Debug, Clone, Eq, PartialEq, Hash)]
32pub struct SplitOff {
33    element_type: DataType,
34}
35
36impl SplitOff {
37    pub const OUT_OF_BOUNDS_ERROR_ID: i128 = 80;
38
39    /// # Panics
40    ///
41    /// Panics if the element has [dynamic length][BFieldCodec::static_length], or
42    /// if the static length is 0.
43    pub fn new(element_type: DataType) -> Self {
44        Get::assert_element_type_is_supported(&element_type);
45
46        Self { element_type }
47    }
48
49    fn self_type(&self) -> DataType {
50        DataType::List(Box::new(self.element_type.to_owned()))
51    }
52}
53
54impl BasicSnippet for SplitOff {
55    fn inputs(&self) -> Vec<(DataType, String)> {
56        vec![
57            (self.self_type(), "self".to_owned()),
58            (DataType::U32, "at".to_owned()),
59        ]
60    }
61
62    fn outputs(&self) -> Vec<(DataType, String)> {
63        vec![(self.self_type(), "new_list".to_owned())]
64    }
65
66    fn entrypoint(&self) -> String {
67        let element_type = self.element_type.label_friendly_name();
68        format!("tasmlib_list_split_off_{element_type}")
69    }
70
71    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
72        let dyn_malloc = library.import(Box::new(DynMalloc));
73        let mem_cpy = library.import(Box::new(MemCpy));
74
75        triton_asm!(
76            // BEFORE: _ *list at
77            // AFTER:  _ *new_list
78            {self.entrypoint()}:
79                /* get original length */
80                pick 1
81                read_mem 1
82                addi 1              // _ at original_length *list
83
84                /* assert `at` is in bounds */
85                dup 2
86                dup 2               // _ at original_length *list at original_length
87                lt
88                push 0
89                eq                  // _ at original_length *list (at <= original_length)
90                assert error_id {Self::OUT_OF_BOUNDS_ERROR_ID}
91                                    // _ at original_length *list
92
93                /* write new length of original list */
94                dup 2
95                place 1             // _ at original_length at *list
96                write_mem 1         // _ at original_length (*list+1)
97
98                /* prepare mem_cpy: *read_source */
99                dup 2               // _ at original_length (*list+1) at
100                push {self.element_type.stack_size()}
101                mul
102                add                 // _ at original_length (*list + 1 + at*element_size)
103                                    // _ at original_length *read_source
104
105                /* allocate new list and set its length */
106                pick 2
107                push -1
108                mul                 // _ original_length *read_source (-at)
109                pick 2
110                add                 // _ *read_source (original_length - at)
111                                    // _ *read_source new_len
112
113                dup 0
114                call {dyn_malloc}   // _ *read_source new_len new_len *new_list
115                dup 0
116                place 4             // _ *new_list *read_source new_len new_len *new_list
117                write_mem 1         // _ *new_list *read_source new_len (*new_list + 1)
118                                    // _ *new_list *read_source new_len *write_dest
119
120                /* prepare mem_cpy: num_words */
121                pick 1
122                push {self.element_type.stack_size()}
123                mul                 // _ *new_list *read_source *write_dest (new_len * element_size)
124                                    // _ *new_list *read_source *write_dest num_words
125
126                call {mem_cpy}      // _ *new_list
127                return
128        )
129    }
130
131    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
132        let mut sign_offs = HashMap::new();
133        match self.element_type.stack_size() {
134            1 => _ = sign_offs.insert(Reviewer("ferdinand"), 0x6740c2eb354b959d.into()),
135            2 => _ = sign_offs.insert(Reviewer("ferdinand"), 0x79cb11ba6120c8eb.into()),
136            3 => _ = sign_offs.insert(Reviewer("ferdinand"), 0x4299b2493e810d49.into()),
137            4 => _ = sign_offs.insert(Reviewer("ferdinand"), 0x9b012d7b60022f84.into()),
138            5 => _ = sign_offs.insert(Reviewer("ferdinand"), 0x8b601b3383a3e967.into()),
139            _ => (),
140        }
141
142        sign_offs
143    }
144}
145
#[cfg(test)]
mod tests {
    use proptest::strategy::Union;

    use super::*;
    use crate::U32_TO_USIZE_ERR;
    use crate::list::LIST_METADATA_SIZE;
    use crate::rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator;
    use crate::rust_shadowing_helper_functions::list::insert_random_list;
    use crate::rust_shadowing_helper_functions::list::list_set_length;
    use crate::rust_shadowing_helper_functions::list::load_list_unstructured;
    use crate::test_helpers::test_assertion_failure;
    use crate::test_prelude::*;

    impl SplitOff {
        /// Place a random list of `list_length` elements at `list_pointer` in
        /// memory and set up the stack as `_ *list at`.
        fn set_up_initial_state(
            &self,
            list_length: usize,
            at: usize,
            list_pointer: BFieldElement,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            insert_random_list(&self.element_type, list_pointer, list_length, &mut memory);

            let mut stack = self.init_stack_for_isolated_run();
            stack.push(list_pointer);
            stack.push(bfe!(at));

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for SplitOff {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let at = pop_encodable::<u32>(stack)
                .try_into()
                .expect(U32_TO_USIZE_ERR);
            let list_pointer = stack.pop().unwrap();

            let mut list =
                load_list_unstructured(self.element_type.stack_size(), list_pointer, memory);
            let new_list = list.split_off(at);

            let new_list_pointer = dynamic_allocator(memory);
            list_set_length(list_pointer, list.len(), memory);
            list_set_length(new_list_pointer, new_list.len(), memory);

            for (offset, word) in (LIST_METADATA_SIZE..).zip(new_list.into_iter().flatten()) {
                memory.insert(new_list_pointer + bfe!(offset), word);
            }
            stack.push(new_list_pointer);
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let (list_length, at) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (100, 50),
                Some(BenchmarkCase::WorstCase) => (1000, 0),
                None => {
                    // The empty list and `at == list_length` are both valid inputs
                    // for `Vec::split_off`, so they are included here.
                    let list_length = rng.random_range(0..1000);
                    (list_length, rng.random_range(0..=list_length))
                }
            };
            let list_pointer = rng.random();

            self.set_up_initial_state(list_length, at, list_pointer)
        }
    }

    #[test]
    fn rust_shadow() {
        for element_type in [
            DataType::U32,
            DataType::U64,
            DataType::Xfe,
            DataType::U128,
            DataType::Digest,
        ] {
            ShadowedFunction::new(SplitOff::new(element_type)).test()
        }
    }

    /// Any `at` strictly greater than the list's length must crash the VM.
    /// `at == list_length` is in bounds (it yields an empty new list) and is
    /// therefore excluded from the strategy.
    #[proptest]
    fn out_of_bounds_index_crashes_vm(
        #[strategy(Union::new(
            [DataType::U32, DataType::U64, DataType::Xfe, DataType::Digest].map(Just)
        ))]
        element_type: DataType,
        #[strategy(0_usize..100)] list_length: usize,
        #[strategy((#list_length + 1)..1 << 30)] at: usize,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        let snippet = SplitOff::new(element_type);
        let initial_state = snippet.set_up_initial_state(list_length, at, list_pointer);

        test_assertion_failure(
            &ShadowedFunction::new(snippet),
            initial_state.into(),
            &[SplitOff::OUT_OF_BOUNDS_ERROR_ID],
        );
    }
}
256
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark splitting off a list of extension-field elements.
    #[test]
    fn benchmark() {
        let snippet = SplitOff::new(DataType::Xfe);
        let shadowed = ShadowedFunction::new(snippet);
        shadowed.bench();
    }
}