Skip to main content

sol_parser_sdk/common/
simd_utils.rs

1use wide::*;
2
3/// SIMD-accelerated data parsing utilities
4pub struct SimdUtils;
5
6impl SimdUtils {
7    /// SIMD-accelerated byte array comparison
8    /// For arrays with length >= 16, uses SIMD instructions for fast comparison
9    #[inline(always)]
10    pub fn fast_bytes_equal(a: &[u8], b: &[u8]) -> bool {
11        if a.len() != b.len() {
12            return false;
13        }
14
15        let len = a.len();
16
17        // For small arrays, use standard comparison directly
18        if len < 16 {
19            return a == b;
20        }
21
22        // Use SIMD to process 16-byte chunks
23        let chunks = len / 16;
24        let remainder = len % 16;
25
26        // Process complete 16-byte chunks
27        for i in 0..chunks {
28            let offset = i * 16;
29            let chunk_a = u8x16::from(&a[offset..offset + 16]);
30            let chunk_b = u8x16::from(&b[offset..offset + 16]);
31
32            if !chunk_a.cmp_eq(chunk_b).all() {
33                return false;
34            }
35        }
36
37        // Process remaining bytes
38        if remainder > 0 {
39            let start = chunks * 16;
40            return a[start..] == b[start..];
41        }
42
43        true
44    }
45
46    /// Fast discriminator matching, specifically for instruction discriminator comparison
47    #[inline(always)]
48    pub fn fast_discriminator_match(data: &[u8], discriminator: &[u8]) -> bool {
49        if data.len() < discriminator.len() {
50            return false;
51        }
52
53        let disc_len = discriminator.len();
54
55        // Optimize for common discriminator lengths
56        match disc_len {
57            1 => data[0] == discriminator[0],
58            2 => {
59                let data_u16 = u16::from_le_bytes([data[0], data[1]]);
60                let disc_u16 = u16::from_le_bytes([discriminator[0], discriminator[1]]);
61                data_u16 == disc_u16
62            }
63            4 => {
64                let data_u32 = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
65                let disc_u32 = u32::from_le_bytes([
66                    discriminator[0],
67                    discriminator[1],
68                    discriminator[2],
69                    discriminator[3],
70                ]);
71                data_u32 == disc_u32
72            }
73            8 => {
74                let data_u64 = u64::from_le_bytes([
75                    data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
76                ]);
77                let disc_u64 = u64::from_le_bytes([
78                    discriminator[0],
79                    discriminator[1],
80                    discriminator[2],
81                    discriminator[3],
82                    discriminator[4],
83                    discriminator[5],
84                    discriminator[6],
85                    discriminator[7],
86                ]);
87                data_u64 == disc_u64
88            }
89            16 => {
90                // Use SIMD to process 16-byte discriminators
91                let data_chunk = u8x16::from(&data[..16]);
92                let disc_chunk = u8x16::from(discriminator);
93                data_chunk.cmp_eq(disc_chunk).all()
94            }
95            _ => {
96                // For other lengths, use generic SIMD comparison
97                Self::fast_bytes_equal(&data[..disc_len], discriminator)
98            }
99        }
100    }
101
102    /// SIMD-accelerated memory search to find specific patterns in data
103    #[inline(always)]
104    pub fn find_pattern_simd(haystack: &[u8], needle: &[u8]) -> Option<usize> {
105        if needle.is_empty() || haystack.len() < needle.len() {
106            return None;
107        }
108
109        let needle_len = needle.len();
110        let haystack_len = haystack.len();
111
112        // For single-byte search, use optimized method
113        if needle_len == 1 {
114            let target = needle[0];
115            return haystack.iter().position(|&b| b == target);
116        }
117
118        // For multi-byte search, use SIMD acceleration
119        if needle_len <= 16 && haystack_len >= 16 {
120            let first_byte = needle[0];
121            let chunks = (haystack_len - needle_len + 1) / 16;
122
123            for chunk_idx in 0..chunks {
124                let start = chunk_idx * 16;
125                let end = std::cmp::min(start + 16, haystack_len - needle_len + 1);
126
127                // Use SIMD to find first byte matches
128                let chunk = &haystack[start..start + 16];
129                let target_vec = u8x16::splat(first_byte);
130                let chunk_vec = u8x16::from(chunk);
131                let matches = chunk_vec.cmp_eq(target_vec);
132
133                // Check each match position
134                let matches_array: [u8; 16] = matches.into();
135                for i in 0..16 {
136                    if start + i >= end {
137                        break;
138                    }
139
140                    if matches_array[i] != 0
141                        && start + i + needle_len <= haystack_len
142                        && Self::fast_bytes_equal(
143                            &haystack[start + i..start + i + needle_len],
144                            needle,
145                        )
146                    {
147                        return Some(start + i);
148                    }
149                }
150            }
151
152            // Process remaining part
153            let remaining_start = chunks * 16;
154            for i in remaining_start..=(haystack_len - needle_len) {
155                if Self::fast_bytes_equal(&haystack[i..i + needle_len], needle) {
156                    return Some(i);
157                }
158            }
159        } else {
160            // Fallback to standard search
161            for i in 0..=(haystack_len - needle_len) {
162                if Self::fast_bytes_equal(&haystack[i..i + needle_len], needle) {
163                    return Some(i);
164                }
165            }
166        }
167
168        None
169    }
170
171    /// SIMD-accelerated data validation to check if data conforms to specific format
172    #[inline(always)]
173    pub fn validate_data_format(data: &[u8], min_length: usize) -> bool {
174        if data.len() < min_length {
175            return false;
176        }
177
178        true
179    }
180
181    /// Fast checksum calculation (maintains API consistency)
182    #[inline(always)]
183    pub fn fast_checksum(data: &[u8]) -> u32 {
184        // Simplified implementation, directly sum all bytes
185        data.iter().map(|&b| b as u32).sum()
186    }
187
188    /// SIMD-accelerated data copy (for large data blocks)
189    #[inline(always)]
190    pub fn fast_copy(src: &[u8], dst: &mut [u8]) {
191        if src.len() != dst.len() {
192            panic!("Source and destination must have the same length");
193        }
194
195        let len = src.len();
196
197        if len >= 32 {
198            // Use 32-byte SIMD copy
199            let chunks = len / 32;
200
201            for i in 0..chunks {
202                let start = i * 32;
203                let src_chunk1 = u8x16::from(&src[start..start + 16]);
204                let src_chunk2 = u8x16::from(&src[start + 16..start + 32]);
205
206                let chunk1_array: [u8; 16] = src_chunk1.into();
207                let chunk2_array: [u8; 16] = src_chunk2.into();
208
209                dst[start..start + 16].copy_from_slice(&chunk1_array);
210                dst[start + 16..start + 32].copy_from_slice(&chunk2_array);
211            }
212
213            // Process remaining bytes
214            let remaining_start = chunks * 32;
215            dst[remaining_start..].copy_from_slice(&src[remaining_start..]);
216        } else {
217            // For small data, use standard copy
218            dst.copy_from_slice(src);
219        }
220    }
221
222    /// SIMD-accelerated account indices validation
223    /// Validates that all indices in the account index array are less than the total account count
224    #[inline(always)]
225    pub fn validate_account_indices_simd(indices: &[u8], account_count: usize) -> bool {
226        if indices.is_empty() {
227            return true;
228        }
229
230        let max_valid_index = account_count as u8;
231
232        // For small arrays, use standard comparison directly
233        if indices.len() < 16 {
234            return indices.iter().all(|&idx| idx < max_valid_index);
235        }
236
237        // Use SIMD for batch loading and comparison
238        let chunks = indices.len() / 16;
239        let remainder = indices.len() % 16;
240
241        // Process complete 16-byte chunks
242        for i in 0..chunks {
243            let start = i * 16;
244            let indices_chunk = u8x16::from(&indices[start..start + 16]);
245
246            // Convert SIMD vector to array for fast batch checking
247            let indices_array: [u8; 16] = indices_chunk.into();
248
249            // Use unrolled loop for fast comparison, compiler will optimize this
250            if indices_array[0] >= max_valid_index
251                || indices_array[1] >= max_valid_index
252                || indices_array[2] >= max_valid_index
253                || indices_array[3] >= max_valid_index
254                || indices_array[4] >= max_valid_index
255                || indices_array[5] >= max_valid_index
256                || indices_array[6] >= max_valid_index
257                || indices_array[7] >= max_valid_index
258                || indices_array[8] >= max_valid_index
259                || indices_array[9] >= max_valid_index
260                || indices_array[10] >= max_valid_index
261                || indices_array[11] >= max_valid_index
262                || indices_array[12] >= max_valid_index
263                || indices_array[13] >= max_valid_index
264                || indices_array[14] >= max_valid_index
265                || indices_array[15] >= max_valid_index
266            {
267                return false;
268            }
269        }
270
271        // Process remaining bytes
272        if remainder > 0 {
273            let remaining_start = chunks * 16;
274            return indices[remaining_start..].iter().all(|&idx| idx < max_valid_index);
275        }
276
277        true
278    }
279
280    /// SIMD-accelerated instruction data validation
281    /// Validates basic format and length requirements of instruction data
282    #[inline(always)]
283    pub fn validate_instruction_data_simd(
284        data: &[u8],
285        min_length: usize,
286        discriminator_length: usize,
287    ) -> bool {
288        // Basic length check
289        if data.len() < min_length || data.len() < discriminator_length {
290            return false;
291        }
292
293        // Use existing data format validation
294        Self::validate_data_format(data, min_length)
295    }
296}