simd_lookup/
eight_value_lookup.rs

//! SIMD-accelerated lookup for finding positions in small `u32` tables.
//!
//! This module provides fast position lookup for a `u32` value in a table of up to
//! 8 `u32` values, returning the position (0-7) of its first occurrence if found,
//! or -1 if not found.
6
7use wide::u32x8;
8
9/// SIMD-accelerated position lookup for finding a u32 value in a table of up to 8 values
10///
11/// This is optimized for the common pattern of finding the index/position of a value
12/// in a small lookup table, which is more useful than simple membership testing.
13pub struct EightValueLookup {
14    table: u32x8,
15    count: usize, // Number of actual values (≤ 8)
16}
17
18impl EightValueLookup {
19    /// Create a new position lookup table from a slice of u32 values
20    ///
21    /// # Panics
22    /// Panics if more than 8 values are provided
23    pub fn new(values: &[u32]) -> Self {
24        assert!(
25            values.len() <= 8,
26            "EightValueLookup supports at most 8 values"
27        );
28
29        let mut array = [0u32; 8];
30        for (i, &val) in values.iter().enumerate() {
31            array[i] = val;
32        }
33
34        Self {
35            table: u32x8::from(array),
36            count: values.len(),
37        }
38    }
39
40    /// Find the position of a u32 value in the lookup table
41    /// Returns the position (0-7) if found, or -1 if not found
42    #[inline]
43    pub fn find_position(&self, value: u32) -> i32 {
44        self.find_position_simd_impl(value)
45    }
46
47    /// Find positions for multiple values at once using SIMD
48    /// Returns an array of positions where each element is the position (0-7) or -1
49    #[inline]
50    pub fn find_positions_batch(&self, values: u32x8) -> [i32; 8] {
51        self.find_positions_batch_simd_impl(values)
52    }
53
54    /// Get the number of values in the lookup table
55    pub fn len(&self) -> usize {
56        self.count
57    }
58
59    /// Check if the lookup table is empty
60    pub fn is_empty(&self) -> bool {
61        self.count == 0
62    }
63
64    /// Get the underlying table as an array (includes padding zeros)
65    pub fn as_array(&self) -> [u32; 8] {
66        self.table.to_array()
67    }
68
69    /// Internal SIMD implementation for single value position lookup
70    #[inline]
71    fn find_position_simd_impl(&self, value: u32) -> i32 {
72        if self.count == 0 {
73            return -1;
74        }
75
76        #[cfg(target_arch = "x86_64")]
77        {
78            self.find_position_simd_avx2(value)
79        }
80
81        #[cfg(target_arch = "aarch64")]
82        {
83            self.find_position_simd_neon(value)
84        }
85
86        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
87        {
88            // Fallback to scalar
89            let table_array = self.table.to_array();
90            for i in 0..self.count {
91                if table_array[i] == value {
92                    return i as i32;
93                }
94            }
95            -1
96        }
97    }
98
99    /// Internal SIMD implementation for batch position lookup
100    #[inline]
101    fn find_positions_batch_simd_impl(&self, values: u32x8) -> [i32; 8] {
102        if self.count == 0 {
103            return [-1; 8];
104        }
105
106        #[cfg(target_arch = "x86_64")]
107        {
108            self.find_positions_batch_simd_avx2(values)
109        }
110
111        #[cfg(target_arch = "aarch64")]
112        {
113            self.find_positions_batch_simd_neon(values)
114        }
115
116        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
117        {
118            // Fallback to scalar
119            let values_array = values.to_array();
120            let table_array = self.table.to_array();
121            let mut result = [-1i32; 8];
122
123            for i in 0..8 {
124                for j in 0..self.count {
125                    if values_array[i] == table_array[j] {
126                        result[i] = j as i32;
127                        break;
128                    }
129                }
130            }
131
132            result
133        }
134    }
135
136    #[cfg(target_arch = "x86_64")]
137    #[inline]
138    fn find_position_simd_avx2(&self, value: u32) -> i32 {
139        unsafe {
140            use std::arch::x86_64::*;
141
142            if is_x86_feature_detected!("avx2") {
143                // Broadcast the input value to all lanes
144                let input_vec = _mm256_set1_epi32(value as i32);
145
146                // Load our table values
147                let table_values = self.table.to_array();
148                let table_vec = _mm256_loadu_si256(table_values.as_ptr() as *const __m256i);
149
150                // Compare all lanes at once
151                let cmp_result = _mm256_cmpeq_epi32(input_vec, table_vec);
152
153                // Extract the comparison mask
154                let mask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_result));
155
156                // Create a mask for only the valid lanes (based on self.count)
157                let valid_mask = (1u32 << self.count) - 1;
158                let masked_result = (mask as u32) & valid_mask;
159
160                if masked_result == 0 {
161                    -1
162                } else {
163                    // Find the position of the first set bit (trailing zeros)
164                    masked_result.trailing_zeros() as i32
165                }
166            } else {
167                // Fallback to scalar if AVX2 not available
168                let table_array = self.table.to_array();
169                for i in 0..self.count {
170                    if table_array[i] == value {
171                        return i as i32;
172                    }
173                }
174                -1
175            }
176        }
177    }
178
179    #[cfg(target_arch = "x86_64")]
180    #[inline]
181    fn find_positions_batch_simd_avx2(&self, values: u32x8) -> [i32; 8] {
182        unsafe {
183            use std::arch::x86_64::*;
184
185            if is_x86_feature_detected!("avx2") {
186                let values_array = values.to_array();
187                let input_vec = _mm256_loadu_si256(values_array.as_ptr() as *const __m256i);
188
189                let table_values = self.table.to_array();
190                // Note: table_vec loaded for potential future SIMD optimization
191                let _table_vec = _mm256_loadu_si256(table_values.as_ptr() as *const __m256i);
192
193                let mut result = [-1i32; 8];
194
195                // For each table position, check which input values match
196                for table_pos in 0..self.count {
197                    // Broadcast current table value to all lanes
198                    let table_broadcast = _mm256_set1_epi32(table_values[table_pos] as i32);
199
200                    // Compare with all input values
201                    let cmp_result = _mm256_cmpeq_epi32(input_vec, table_broadcast);
202
203                    // Extract mask and update results
204                    let mask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_result));
205
206                    for i in 0..8 {
207                        if (mask & (1 << i)) != 0 && result[i] == -1 {
208                            // First match for this input value
209                            result[i] = table_pos as i32;
210                        }
211                    }
212                }
213
214                result
215            } else {
216                // Fallback to scalar
217                let values_array = values.to_array();
218                let table_array = self.table.to_array();
219                let mut result = [-1i32; 8];
220
221                for i in 0..8 {
222                    for j in 0..self.count {
223                        if values_array[i] == table_array[j] {
224                            result[i] = j as i32;
225                            break;
226                        }
227                    }
228                }
229
230                result
231            }
232        }
233    }
234
235    #[cfg(target_arch = "aarch64")]
236    #[inline]
237    fn find_position_simd_neon(&self, value: u32) -> i32 {
238        unsafe {
239            use std::arch::aarch64::*;
240
241            // Load our table values (2 NEON vectors of 4 u32 each)
242            let table_values = self.table.to_array();
243            let table_vec1 = vld1q_u32(table_values.as_ptr());
244            let table_vec2 = vld1q_u32(table_values.as_ptr().add(4));
245
246            // Broadcast the input value
247            let input_vec = vdupq_n_u32(value);
248
249            // Compare with both halves
250            let cmp1 = vceqq_u32(input_vec, table_vec1);
251            let cmp2 = vceqq_u32(input_vec, table_vec2);
252
253            // Convert to arrays to find position
254            let cmp1_array: [u32; 4] = std::mem::transmute(cmp1);
255            let cmp2_array: [u32; 4] = std::mem::transmute(cmp2);
256
257            // Check first half
258            for i in 0..4.min(self.count) {
259                if cmp1_array[i] != 0 {
260                    return i as i32;
261                }
262            }
263
264            // Check second half if needed
265            if self.count > 4 {
266                for i in 0..(self.count - 4) {
267                    if cmp2_array[i] != 0 {
268                        return (i + 4) as i32;
269                    }
270                }
271            }
272
273            -1
274        }
275    }
276
277    #[cfg(target_arch = "aarch64")]
278    #[inline]
279    fn find_positions_batch_simd_neon(&self, values: u32x8) -> [i32; 8] {
280        unsafe {
281            use std::arch::aarch64::*;
282
283            let values_array = values.to_array();
284            let input_vec1 = vld1q_u32(values_array.as_ptr());
285            let input_vec2 = vld1q_u32(values_array.as_ptr().add(4));
286
287            let table_values = self.table.to_array();
288
289            let mut result = [-1i32; 8];
290
291            // Check each table position against all input values
292            for table_pos in 0..self.count {
293                let table_broadcast = vdupq_n_u32(table_values[table_pos]);
294
295                let cmp1 = vceqq_u32(input_vec1, table_broadcast);
296                let cmp2 = vceqq_u32(input_vec2, table_broadcast);
297
298                let cmp1_array: [u32; 4] = std::mem::transmute(cmp1);
299                let cmp2_array: [u32; 4] = std::mem::transmute(cmp2);
300
301                // Update results for first half
302                for i in 0..4 {
303                    if cmp1_array[i] != 0 && result[i] == -1 {
304                        result[i] = table_pos as i32;
305                    }
306                }
307
308                // Update results for second half
309                for i in 0..4 {
310                    if cmp2_array[i] != 0 && result[i + 4] == -1 {
311                        result[i + 4] = table_pos as i32;
312                    }
313                }
314            }
315
316            result
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_position_lookup() {
        let lookup = EightValueLookup::new(&[10, 20, 30, 40]);

        assert_eq!(lookup.find_position(10), 0);
        assert_eq!(lookup.find_position(20), 1);
        assert_eq!(lookup.find_position(30), 2);
        assert_eq!(lookup.find_position(40), 3);
        assert_eq!(lookup.find_position(5), -1);
        assert_eq!(lookup.find_position(50), -1);
    }

    #[test]
    fn test_full_table() {
        let lookup = EightValueLookup::new(&[1, 2, 3, 4, 5, 6, 7, 8]);

        for i in 1..=8 {
            assert_eq!(lookup.find_position(i), (i - 1) as i32);
        }

        assert_eq!(lookup.find_position(0), -1);
        assert_eq!(lookup.find_position(9), -1);
    }

    #[test]
    fn test_empty_table() {
        let lookup = EightValueLookup::new(&[]);

        assert_eq!(lookup.find_position(0), -1);
        assert_eq!(lookup.find_position(1), -1);
        assert!(lookup.is_empty());
        assert_eq!(lookup.len(), 0);
    }

    #[test]
    fn test_single_value() {
        let lookup = EightValueLookup::new(&[42]);

        assert_eq!(lookup.find_position(42), 0);
        assert_eq!(lookup.find_position(41), -1);
        assert_eq!(lookup.find_position(43), -1);
        assert_eq!(lookup.len(), 1);
    }

    #[test]
    fn test_batch_position_lookup() {
        let lookup = EightValueLookup::new(&[10, 20, 30, 40, 50]);

        let test_values = u32x8::from([10, 15, 20, 25, 30, 35, 40, 45]);
        let results = lookup.find_positions_batch(test_values);

        let expected = [0, -1, 1, -1, 2, -1, 3, -1];
        assert_eq!(results, expected);
    }

    #[test]
    fn test_duplicates_return_first_position() {
        let lookup = EightValueLookup::new(&[10, 20, 10, 30, 20]);

        // Should return the first occurrence
        assert_eq!(lookup.find_position(10), 0);
        assert_eq!(lookup.find_position(20), 1);
        assert_eq!(lookup.find_position(30), 3);
    }

    #[test]
    fn test_batch_duplicates_return_first_position() {
        let lookup = EightValueLookup::new(&[10, 20, 10, 30, 20]);

        let results = lookup.find_positions_batch(u32x8::from([10, 20, 30, 40, 10, 20, 30, 0]));
        assert_eq!(results, [0, 1, 3, -1, 0, 1, 3, -1]);
    }

    #[test]
    fn test_padding_zeros_never_match() {
        // Lanes past `len()` are zero-filled; looking up 0 must not hit that padding.
        let lookup = EightValueLookup::new(&[10, 20, 30]);

        assert_eq!(lookup.find_position(0), -1);
        assert_eq!(
            lookup.find_positions_batch(u32x8::from([0, 10, 0, 20, 0, 30, 0, 0])),
            [-1, 0, -1, 1, -1, 2, -1, -1]
        );
    }

    #[test]
    fn test_zero_as_real_value() {
        let lookup = EightValueLookup::new(&[5, 0, 7]);

        assert_eq!(lookup.find_position(0), 1);
        assert_eq!(lookup.find_position(7), 2);
        assert_eq!(
            lookup.find_positions_batch(u32x8::from([0, 5, 7, 1, 0, 0, 0, 0])),
            [1, 0, 2, -1, 1, 1, 1, 1]
        );
    }

    #[test]
    fn test_batch_empty_table() {
        let lookup = EightValueLookup::new(&[]);

        assert_eq!(lookup.find_positions_batch(u32x8::from([0; 8])), [-1; 8]);
    }

    #[test]
    fn test_every_table_size() {
        // Exercise each occupancy 0..=8 so every lane-mask width is covered for both APIs.
        // The values are distinct and non-zero, so the expected position is unambiguous.
        let full = [3u32, 1, 4, 15, 9, 2, 6, 5];

        for len in 0..=8 {
            let lookup = EightValueLookup::new(&full[..len]);
            assert_eq!(lookup.len(), len);

            let batch = lookup.find_positions_batch(u32x8::from(full));
            for (i, &value) in full.iter().enumerate() {
                let expected = if i < len { i as i32 } else { -1 };
                assert_eq!(lookup.find_position(value), expected, "len={len} value={value}");
                assert_eq!(batch[i], expected, "len={len} lane={i}");
            }
        }
    }

    #[test]
    fn test_large_values() {
        let lookup = EightValueLookup::new(&[
            u32::MAX - 7,
            u32::MAX - 5,
            u32::MAX - 3,
            u32::MAX - 1,
            u32::MAX,
        ]);

        assert_eq!(lookup.find_position(u32::MAX), 4);
        assert_eq!(lookup.find_position(u32::MAX - 1), 3);
        assert_eq!(lookup.find_position(u32::MAX - 3), 2);
        assert_eq!(lookup.find_position(u32::MAX - 5), 1);
        assert_eq!(lookup.find_position(u32::MAX - 7), 0);

        assert_eq!(lookup.find_position(u32::MAX - 2), -1);
        assert_eq!(lookup.find_position(u32::MAX - 4), -1);
    }

    #[test]
    fn test_batch_vs_single_consistency() {
        let lookup = EightValueLookup::new(&[5, 15, 25, 35, 45, 55, 65, 75]);

        let test_values = u32x8::from([5, 10, 15, 20, 25, 30, 35, 40]);
        let batch_results = lookup.find_positions_batch(test_values);

        let test_array = test_values.to_array();
        for (i, &test_val) in test_array.iter().enumerate() {
            let single_result = lookup.find_position(test_val);
            assert_eq!(
                batch_results[i], single_result,
                "Mismatch for value {} at index {}: batch={}, single={}",
                test_val, i, batch_results[i], single_result
            );
        }
    }

    #[test]
    #[should_panic(expected = "EightValueLookup supports at most 8 values")]
    fn test_too_many_values() {
        EightValueLookup::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_as_array() {
        let lookup = EightValueLookup::new(&[10, 20, 30]);

        // Remaining elements are 0 (padding).
        assert_eq!(lookup.as_array(), [10, 20, 30, 0, 0, 0, 0, 0]);
    }
}