1use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
2use triton_vm::prelude::*;
3
4use crate::list::LIST_METADATA_SIZE;
5use crate::prelude::*;
6
7#[derive(Debug, Clone, Eq, PartialEq, Hash)]
8pub struct SwapUnchecked {
9 element_type: DataType,
10}
11
12impl SwapUnchecked {
13 pub fn new(element_type: DataType) -> Self {
14 Self { element_type }
15 }
16}
17
18impl BasicSnippet for SwapUnchecked {
19 fn inputs(&self) -> Vec<(DataType, String)> {
20 let self_type = DataType::List(Box::new(self.element_type.to_owned()));
21
22 vec![
23 (self_type, "self".to_owned()),
24 (DataType::U32, "a".to_owned()),
25 (DataType::U32, "b".to_owned()),
26 ]
27 }
28
29 fn outputs(&self) -> Vec<(DataType, String)> {
30 vec![]
31 }
32
33 fn entrypoint(&self) -> String {
34 format!(
35 "tasmlib_list_swap_{}",
36 self.element_type.label_friendly_name()
37 )
38 }
39
40 fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
41 let metadata_size = LIST_METADATA_SIZE;
42 let element_size = self.element_type.stack_size();
43 assert!(
44 element_size + 2 < NUM_OP_STACK_REGISTERS,
45 "This implementation can only handle swap up to element size 13"
46 );
47
48 let mul_with_size = if element_size == 1 {
49 triton_asm!()
50 } else {
51 triton_asm!(
52 push {element_size}
53 mul
54 )
55 };
56
57 let get_offset_for_last_word_in_element = if element_size == 1 {
58 triton_asm!(
59 addi { metadata_size } )
62 } else {
63 triton_asm!(
64 {&mul_with_size}
67 addi {metadata_size}
70 addi {element_size - 1}
73 )
75 };
76
77 triton_asm!(
78 {self.entrypoint()}:
81
82 {&get_offset_for_last_word_in_element}
86 dup 2
89 add
90 {&self.element_type.read_value_from_memory_leave_pointer()}
93 addi 1
96 dup {element_size + 2}
99 dup {element_size + 2}
100 {&get_offset_for_last_word_in_element}
103 add
106 {&self.element_type.read_value_from_memory_leave_pointer()}
109 addi 1
112 swap {element_size + 1}
115 {&self.element_type.write_value_to_memory_pop_pointer()}
118 {&self.element_type.write_value_to_memory_leave_pointer()}
123 pop 3
126
127 return
128 )
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::empty_stack;
    use crate::rust_shadowing_helper_functions::list::insert_random_list;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::rust_shadowing_helper_functions::list::list_set;
    use crate::test_prelude::*;

    impl SwapUnchecked {
        /// Builds an initial state holding a random list of `list_length`
        /// elements at `list_pointer`, with stack `_ *list a b`.
        fn initial_state(
            &self,
            list_pointer: BFieldElement,
            list_length: usize,
            a: usize,
            b: usize,
        ) -> AlgorithmInitialState {
            let mut ram = HashMap::default();
            insert_random_list(&self.element_type, list_pointer, list_length, &mut ram);

            let mut stack = empty_stack();
            stack.extend(bfe_vec![list_pointer, a, b]);

            AlgorithmInitialState {
                stack,
                nondeterminism: NonDeterminism::default().with_ram(ram),
            }
        }
    }

    impl Algorithm for SwapUnchecked {
        /// Reference implementation: swaps the two list elements in `memory`.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
            _: &NonDeterminism,
        ) {
            let element_size = self.element_type.stack_size();

            // `b` is on top of the stack, so arguments come off in reverse.
            let index_b = stack.pop().unwrap().value() as usize;
            let index_a = stack.pop().unwrap().value() as usize;
            let list_pointer = stack.pop().unwrap();

            let element_a = list_get(list_pointer, index_a, memory, element_size);
            let element_b = list_get(list_pointer, index_b, memory, element_size);
            list_set(list_pointer, index_a, element_b, memory);
            list_set(list_pointer, index_b, element_a, memory);
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _: Option<BenchmarkCase>,
        ) -> AlgorithmInitialState {
            let mut rng = StdRng::from_seed(seed);
            let list_pointer: BFieldElement = rng.random();
            let list_length: usize = rng.random_range(1..200);
            let a = rng.random_range(0..list_length);
            let b = rng.random_range(0..list_length);

            self.initial_state(list_pointer, list_length, a, b)
        }

        /// Swapping an index with itself, including in single-element lists.
        fn corner_case_initial_states(&self) -> Vec<AlgorithmInitialState> {
            let list_pointer = bfe!(1);

            [(5, 0, 0), (1, 0, 0), (2, 1, 1)]
                .into_iter()
                .map(|(list_length, a, b)| self.initial_state(list_pointer, list_length, a, b))
                .collect()
        }
    }

    #[test]
    fn test() {
        let element_types = [
            DataType::Bfe,
            DataType::Bool,
            DataType::U64,
            DataType::Xfe,
            DataType::U128,
            DataType::Digest,
            DataType::Tuple(vec![DataType::Xfe, DataType::Xfe]),
            DataType::Tuple(vec![DataType::Digest, DataType::U64]),
            DataType::Tuple(vec![DataType::Xfe, DataType::Digest]),
            DataType::Tuple(vec![DataType::Xfe, DataType::Xfe, DataType::Xfe]),
            DataType::Tuple(vec![DataType::Digest, DataType::Digest]),
            DataType::Tuple(vec![DataType::Xfe, DataType::Xfe, DataType::Digest]),
            DataType::Tuple(vec![
                DataType::Xfe,
                DataType::Xfe,
                DataType::Xfe,
                DataType::Xfe,
            ]),
            DataType::Tuple(vec![DataType::Digest, DataType::Digest, DataType::Xfe]),
        ];

        for element_type in element_types {
            let snippet = SwapUnchecked::new(element_type);
            ShadowedAlgorithm::new(snippet).test();
        }
    }
}
231
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks swapping two extension-field elements.
    #[test]
    fn benchmark() {
        let snippet = SwapUnchecked::new(DataType::Xfe);
        ShadowedAlgorithm::new(snippet).bench();
    }
}