tasm_lib/list/
swap_unchecked.rs

1use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
2use triton_vm::prelude::*;
3
4use crate::list::LIST_METADATA_SIZE;
5use crate::prelude::*;
6
7#[derive(Debug, Clone, Eq, PartialEq, Hash)]
8pub struct SwapUnchecked {
9    element_type: DataType,
10}
11
12impl SwapUnchecked {
13    pub fn new(element_type: DataType) -> Self {
14        Self { element_type }
15    }
16}
17
18impl BasicSnippet for SwapUnchecked {
19    fn inputs(&self) -> Vec<(DataType, String)> {
20        let self_type = DataType::List(Box::new(self.element_type.to_owned()));
21
22        vec![
23            (self_type, "self".to_owned()),
24            (DataType::U32, "a".to_owned()),
25            (DataType::U32, "b".to_owned()),
26        ]
27    }
28
29    fn outputs(&self) -> Vec<(DataType, String)> {
30        vec![]
31    }
32
33    fn entrypoint(&self) -> String {
34        format!(
35            "tasmlib_list_swap_{}",
36            self.element_type.label_friendly_name()
37        )
38    }
39
40    fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
41        let metadata_size = LIST_METADATA_SIZE;
42        let element_size = self.element_type.stack_size();
43        assert!(
44            element_size + 2 < NUM_OP_STACK_REGISTERS,
45            "This implementation can only handle swap up to element size 13"
46        );
47
48        let mul_with_size = if element_size == 1 {
49            triton_asm!()
50        } else {
51            triton_asm!(
52                push {element_size}
53                mul
54            )
55        };
56
57        let get_offset_for_last_word_in_element = if element_size == 1 {
58            triton_asm!(
59                // _ i
60                addi { metadata_size } // _ i_offset_last_word
61            )
62        } else {
63            triton_asm!(
64                // _ i
65
66                {&mul_with_size}
67                // _ i_offset_internal
68
69                addi {metadata_size}
70                // _ i_offset
71
72                addi {element_size - 1}
73                // _ i_offset_last_word
74            )
75        };
76
77        triton_asm!(
78                // BEFORE: _ *list a b
79                // AFTER:  _
80                {self.entrypoint()}:
81
82                    // calculate *list[b]
83                    // _ *list a b
84
85                    {&get_offset_for_last_word_in_element}
86                    // _ *list a b_offset_last_word
87
88                    dup 2
89                    add
90                    // _ *list a *list[b]_last_word
91
92                    {&self.element_type.read_value_from_memory_leave_pointer()}
93                    // _ *list a [list[b]] (*list[b] - 1)
94
95                    addi 1
96                    // _ *list a [list[b]] *list[b]
97
98                    dup {element_size + 2}
99                    dup {element_size + 2}
100                    // _ *list a [list[b]] *list[b] *list a
101
102                    {&get_offset_for_last_word_in_element}
103                    // _ *list a [list[b]] *list[b] *list a_offset_last_word
104
105                    add
106                    // _ *list a [list[b]] *list[b] *list[a]_last_word
107
108                    {&self.element_type.read_value_from_memory_leave_pointer()}
109                    // _ *list a [list[b]] *list[b] [list[a]] (*list[a] - 1)
110
111                    addi 1
112                    // _ *list a [list[b]] *list[b] [list[a]] *list[a]
113
114                    swap {element_size + 1}
115                    // _ *list a [list[b]] *list[a] [list[a]] *list[b]
116
117                    {&self.element_type.write_value_to_memory_pop_pointer()}
118                    // _ *list a [list[b]] *list[a]
119
120                    // We leave pointer here, since it's more efficient to just
121                    // pop all garbage in one fell swoop at the end.
122                    {&self.element_type.write_value_to_memory_leave_pointer()}
123                    // _ *list a *some_pointer
124
125                    pop 3
126
127                    return
128        )
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::empty_stack;
    use crate::rust_shadowing_helper_functions::list::insert_random_list;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::rust_shadowing_helper_functions::list::list_set;
    use crate::test_prelude::*;

    impl SwapUnchecked {
        /// Sets up a state in which a random list of `list_length` elements
        /// lives at `list_pointer`, and the stack ends in
        /// `list_pointer a b`.
        fn initial_state(
            &self,
            list_pointer: BFieldElement,
            list_length: usize,
            a: usize,
            b: usize,
        ) -> AlgorithmInitialState {
            let mut ram = HashMap::default();
            insert_random_list(&self.element_type, list_pointer, list_length, &mut ram);

            let mut stack = empty_stack();
            stack.extend(bfe_vec![list_pointer, a, b]);
            let nondeterminism = NonDeterminism::default().with_ram(ram);

            AlgorithmInitialState {
                stack,
                nondeterminism,
            }
        }
    }

    impl Algorithm for SwapUnchecked {
        /// Reference implementation: pops `b`, `a` and the list pointer,
        /// then exchanges the two list elements in `memory`.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
            _: &NonDeterminism,
        ) {
            fn as_index(element: BFieldElement) -> usize {
                element.value() as usize
            }

            // arguments come off the stack in reverse order
            let index_of_b = as_index(stack.pop().unwrap());
            let index_of_a = as_index(stack.pop().unwrap());
            let list_pointer = stack.pop().unwrap();
            let element_size = self.element_type.stack_size();

            let element_a = list_get(list_pointer, index_of_a, memory, element_size);
            let element_b = list_get(list_pointer, index_of_b, memory, element_size);
            list_set(list_pointer, index_of_a, element_b, memory);
            list_set(list_pointer, index_of_b, element_a, memory);
        }

        /// Random list location and length, with two random in-bounds
        /// indices.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _: Option<BenchmarkCase>,
        ) -> AlgorithmInitialState {
            let mut rng = StdRng::from_seed(seed);
            let pointer = rng.random();
            let length = rng.random_range(1..200);
            let first_index = rng.random_range(0..length);
            let second_index = rng.random_range(0..length);

            self.initial_state(pointer, length, first_index, second_index)
        }

        /// Swapping an element with itself, for a few list lengths.
        fn corner_case_initial_states(&self) -> Vec<AlgorithmInitialState> {
            // (list length, a, b)
            [(5, 0, 0), (1, 0, 0), (2, 1, 1)]
                .into_iter()
                .map(|(length, a, b)| self.initial_state(bfe!(1), length, a, b))
                .collect()
        }
    }

    #[test]
    fn test() {
        let element_types = vec![
            DataType::Bfe,
            DataType::Bool,
            DataType::U64,
            DataType::Xfe,
            DataType::U128,
            DataType::Digest,
            DataType::Tuple(vec![DataType::Xfe; 2]),
            DataType::Tuple(vec![DataType::Digest, DataType::U64]),
            DataType::Tuple(vec![DataType::Xfe, DataType::Digest]),
            DataType::Tuple(vec![DataType::Xfe; 3]),
            DataType::Tuple(vec![DataType::Digest; 2]),
            DataType::Tuple(vec![DataType::Xfe, DataType::Xfe, DataType::Digest]),
            DataType::Tuple(vec![DataType::Xfe; 4]),
            DataType::Tuple(vec![DataType::Digest, DataType::Digest, DataType::Xfe]),
        ];

        for element_type in element_types {
            let snippet = SwapUnchecked::new(element_type);
            ShadowedAlgorithm::new(snippet).test();
        }
    }
}
231
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet for lists of `Xfe` elements.
    #[test]
    fn benchmark() {
        let snippet = SwapUnchecked::new(DataType::Xfe);
        ShadowedAlgorithm::new(snippet).bench();
    }
}