// sol_parser_sdk/common/simd_utils.rs

1use wide::*;
2
/// Stateless namespace for byte-slice helpers: equality, prefix (discriminator)
/// matching, pattern search, length validation, checksum and copy.
///
/// All functionality is exposed as associated functions; some of them use the
/// `wide` crate's 16-lane `u8x16` byte vectors.
#[derive(Debug, Clone, Copy)]
pub struct SimdUtils;
5
6impl SimdUtils {
7    /// SIMD-accelerated byte array comparison
8    /// For arrays with length >= 16, uses SIMD instructions for fast comparison
9    #[inline(always)]
10    pub fn fast_bytes_equal(a: &[u8], b: &[u8]) -> bool {
11        if a.len() != b.len() {
12            return false;
13        }
14
15        let len = a.len();
16
17        // For small arrays, use standard comparison directly
18        if len < 16 {
19            return a == b;
20        }
21
22        // Use SIMD to process 16-byte chunks
23        let chunks = len / 16;
24        let remainder = len % 16;
25
26        // Process complete 16-byte chunks
27        for i in 0..chunks {
28            let offset = i * 16;
29            let chunk_a = u8x16::from(&a[offset..offset + 16]);
30            let chunk_b = u8x16::from(&b[offset..offset + 16]);
31
32            if !chunk_a.cmp_eq(chunk_b).all() {
33                return false;
34            }
35        }
36
37        // Process remaining bytes
38        if remainder > 0 {
39            let start = chunks * 16;
40            return &a[start..] == &b[start..];
41        }
42
43        true
44    }
45
46    /// Fast discriminator matching, specifically for instruction discriminator comparison
47    #[inline(always)]
48    pub fn fast_discriminator_match(data: &[u8], discriminator: &[u8]) -> bool {
49        if data.len() < discriminator.len() {
50            return false;
51        }
52
53        let disc_len = discriminator.len();
54
55        // Optimize for common discriminator lengths
56        match disc_len {
57            1 => data[0] == discriminator[0],
58            2 => {
59                let data_u16 = u16::from_le_bytes([data[0], data[1]]);
60                let disc_u16 = u16::from_le_bytes([discriminator[0], discriminator[1]]);
61                data_u16 == disc_u16
62            }
63            4 => {
64                let data_u32 = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
65                let disc_u32 = u32::from_le_bytes([
66                    discriminator[0],
67                    discriminator[1],
68                    discriminator[2],
69                    discriminator[3],
70                ]);
71                data_u32 == disc_u32
72            }
73            8 => {
74                let data_u64 = u64::from_le_bytes([
75                    data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
76                ]);
77                let disc_u64 = u64::from_le_bytes([
78                    discriminator[0],
79                    discriminator[1],
80                    discriminator[2],
81                    discriminator[3],
82                    discriminator[4],
83                    discriminator[5],
84                    discriminator[6],
85                    discriminator[7],
86                ]);
87                data_u64 == disc_u64
88            }
89            16 => {
90                // Use SIMD to process 16-byte discriminators
91                let data_chunk = u8x16::from(&data[..16]);
92                let disc_chunk = u8x16::from(discriminator);
93                data_chunk.cmp_eq(disc_chunk).all()
94            }
95            _ => {
96                // For other lengths, use generic SIMD comparison
97                Self::fast_bytes_equal(&data[..disc_len], discriminator)
98            }
99        }
100    }
101
102    /// SIMD-accelerated memory search to find specific patterns in data
103    #[inline(always)]
104    pub fn find_pattern_simd(haystack: &[u8], needle: &[u8]) -> Option<usize> {
105        if needle.is_empty() || haystack.len() < needle.len() {
106            return None;
107        }
108
109        let needle_len = needle.len();
110        let haystack_len = haystack.len();
111
112        // For single-byte search, use optimized method
113        if needle_len == 1 {
114            let target = needle[0];
115            return haystack.iter().position(|&b| b == target);
116        }
117
118        // For multi-byte search, use SIMD acceleration
119        if needle_len <= 16 && haystack_len >= 16 {
120            let first_byte = needle[0];
121            let chunks = (haystack_len - needle_len + 1) / 16;
122
123            for chunk_idx in 0..chunks {
124                let start = chunk_idx * 16;
125                let end = std::cmp::min(start + 16, haystack_len - needle_len + 1);
126
127                // Use SIMD to find first byte matches
128                let chunk = &haystack[start..start + 16];
129                let target_vec = u8x16::splat(first_byte);
130                let chunk_vec = u8x16::from(chunk);
131                let matches = chunk_vec.cmp_eq(target_vec);
132
133                // Check each match position
134                let matches_array: [u8; 16] = matches.into();
135                for i in 0..16 {
136                    if start + i >= end {
137                        break;
138                    }
139
140                    if matches_array[i] != 0 && start + i + needle_len <= haystack_len {
141                        if Self::fast_bytes_equal(
142                            &haystack[start + i..start + i + needle_len],
143                            needle,
144                        ) {
145                            return Some(start + i);
146                        }
147                    }
148                }
149            }
150
151            // Process remaining part
152            let remaining_start = chunks * 16;
153            for i in remaining_start..=(haystack_len - needle_len) {
154                if Self::fast_bytes_equal(&haystack[i..i + needle_len], needle) {
155                    return Some(i);
156                }
157            }
158        } else {
159            // Fallback to standard search
160            for i in 0..=(haystack_len - needle_len) {
161                if Self::fast_bytes_equal(&haystack[i..i + needle_len], needle) {
162                    return Some(i);
163                }
164            }
165        }
166
167        None
168    }
169
170    /// SIMD-accelerated data validation to check if data conforms to specific format
171    #[inline(always)]
172    pub fn validate_data_format(data: &[u8], min_length: usize) -> bool {
173        if data.len() < min_length {
174            return false;
175        }
176
177        true
178    }
179
180    /// Fast checksum calculation (maintains API consistency)
181    #[inline(always)]
182    pub fn fast_checksum(data: &[u8]) -> u32 {
183        // Simplified implementation, directly sum all bytes
184        data.iter().map(|&b| b as u32).sum()
185    }
186
187    /// SIMD-accelerated data copy (for large data blocks)
188    #[inline(always)]
189    pub fn fast_copy(src: &[u8], dst: &mut [u8]) {
190        if src.len() != dst.len() {
191            panic!("Source and destination must have the same length");
192        }
193
194        let len = src.len();
195
196        if len >= 32 {
197            // Use 32-byte SIMD copy
198            let chunks = len / 32;
199
200            for i in 0..chunks {
201                let start = i * 32;
202                let src_chunk1 = u8x16::from(&src[start..start + 16]);
203                let src_chunk2 = u8x16::from(&src[start + 16..start + 32]);
204
205                let chunk1_array: [u8; 16] = src_chunk1.into();
206                let chunk2_array: [u8; 16] = src_chunk2.into();
207
208                dst[start..start + 16].copy_from_slice(&chunk1_array);
209                dst[start + 16..start + 32].copy_from_slice(&chunk2_array);
210            }
211
212            // Process remaining bytes
213            let remaining_start = chunks * 32;
214            dst[remaining_start..].copy_from_slice(&src[remaining_start..]);
215        } else {
216            // For small data, use standard copy
217            dst.copy_from_slice(src);
218        }
219    }
220
221    /// SIMD-accelerated account indices validation
222    /// Validates that all indices in the account index array are less than the total account count
223    #[inline(always)]
224    pub fn validate_account_indices_simd(indices: &[u8], account_count: usize) -> bool {
225        if indices.is_empty() {
226            return true;
227        }
228
229        let max_valid_index = account_count as u8;
230
231        // For small arrays, use standard comparison directly
232        if indices.len() < 16 {
233            return indices.iter().all(|&idx| idx < max_valid_index);
234        }
235
236        // Use SIMD for batch loading and comparison
237        let chunks = indices.len() / 16;
238        let remainder = indices.len() % 16;
239
240        // Process complete 16-byte chunks
241        for i in 0..chunks {
242            let start = i * 16;
243            let indices_chunk = u8x16::from(&indices[start..start + 16]);
244
245            // Convert SIMD vector to array for fast batch checking
246            let indices_array: [u8; 16] = indices_chunk.into();
247
248            // Use unrolled loop for fast comparison, compiler will optimize this
249            if indices_array[0] >= max_valid_index
250                || indices_array[1] >= max_valid_index
251                || indices_array[2] >= max_valid_index
252                || indices_array[3] >= max_valid_index
253                || indices_array[4] >= max_valid_index
254                || indices_array[5] >= max_valid_index
255                || indices_array[6] >= max_valid_index
256                || indices_array[7] >= max_valid_index
257                || indices_array[8] >= max_valid_index
258                || indices_array[9] >= max_valid_index
259                || indices_array[10] >= max_valid_index
260                || indices_array[11] >= max_valid_index
261                || indices_array[12] >= max_valid_index
262                || indices_array[13] >= max_valid_index
263                || indices_array[14] >= max_valid_index
264                || indices_array[15] >= max_valid_index
265            {
266                return false;
267            }
268        }
269
270        // Process remaining bytes
271        if remainder > 0 {
272            let remaining_start = chunks * 16;
273            return indices[remaining_start..].iter().all(|&idx| idx < max_valid_index);
274        }
275
276        true
277    }
278
279    /// SIMD-accelerated instruction data validation
280    /// Validates basic format and length requirements of instruction data
281    #[inline(always)]
282    pub fn validate_instruction_data_simd(
283        data: &[u8],
284        min_length: usize,
285        discriminator_length: usize,
286    ) -> bool {
287        // Basic length check
288        if data.len() < min_length || data.len() < discriminator_length {
289            return false;
290        }
291
292        // Use existing data format validation
293        Self::validate_data_format(data, min_length)
294    }
295}