1use wide::u32x8;
8
9pub struct EightValueLookup {
14 table: u32x8,
15 count: usize, }
17
18impl EightValueLookup {
19 pub fn new(values: &[u32]) -> Self {
24 assert!(
25 values.len() <= 8,
26 "EightValueLookup supports at most 8 values"
27 );
28
29 let mut array = [0u32; 8];
30 for (i, &val) in values.iter().enumerate() {
31 array[i] = val;
32 }
33
34 Self {
35 table: u32x8::from(array),
36 count: values.len(),
37 }
38 }
39
40 #[inline]
43 pub fn find_position(&self, value: u32) -> i32 {
44 self.find_position_simd_impl(value)
45 }
46
47 #[inline]
50 pub fn find_positions_batch(&self, values: u32x8) -> [i32; 8] {
51 self.find_positions_batch_simd_impl(values)
52 }
53
54 pub fn len(&self) -> usize {
56 self.count
57 }
58
59 pub fn is_empty(&self) -> bool {
61 self.count == 0
62 }
63
64 pub fn as_array(&self) -> [u32; 8] {
66 self.table.to_array()
67 }
68
69 #[inline]
71 fn find_position_simd_impl(&self, value: u32) -> i32 {
72 if self.count == 0 {
73 return -1;
74 }
75
76 #[cfg(target_arch = "x86_64")]
77 {
78 self.find_position_simd_avx2(value)
79 }
80
81 #[cfg(target_arch = "aarch64")]
82 {
83 self.find_position_simd_neon(value)
84 }
85
86 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
87 {
88 let table_array = self.table.to_array();
90 for i in 0..self.count {
91 if table_array[i] == value {
92 return i as i32;
93 }
94 }
95 -1
96 }
97 }
98
99 #[inline]
101 fn find_positions_batch_simd_impl(&self, values: u32x8) -> [i32; 8] {
102 if self.count == 0 {
103 return [-1; 8];
104 }
105
106 #[cfg(target_arch = "x86_64")]
107 {
108 self.find_positions_batch_simd_avx2(values)
109 }
110
111 #[cfg(target_arch = "aarch64")]
112 {
113 self.find_positions_batch_simd_neon(values)
114 }
115
116 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
117 {
118 let values_array = values.to_array();
120 let table_array = self.table.to_array();
121 let mut result = [-1i32; 8];
122
123 for i in 0..8 {
124 for j in 0..self.count {
125 if values_array[i] == table_array[j] {
126 result[i] = j as i32;
127 break;
128 }
129 }
130 }
131
132 result
133 }
134 }
135
136 #[cfg(target_arch = "x86_64")]
137 #[inline]
138 fn find_position_simd_avx2(&self, value: u32) -> i32 {
139 unsafe {
140 use std::arch::x86_64::*;
141
142 if is_x86_feature_detected!("avx2") {
143 let input_vec = _mm256_set1_epi32(value as i32);
145
146 let table_values = self.table.to_array();
148 let table_vec = _mm256_loadu_si256(table_values.as_ptr() as *const __m256i);
149
150 let cmp_result = _mm256_cmpeq_epi32(input_vec, table_vec);
152
153 let mask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_result));
155
156 let valid_mask = (1u32 << self.count) - 1;
158 let masked_result = (mask as u32) & valid_mask;
159
160 if masked_result == 0 {
161 -1
162 } else {
163 masked_result.trailing_zeros() as i32
165 }
166 } else {
167 let table_array = self.table.to_array();
169 for i in 0..self.count {
170 if table_array[i] == value {
171 return i as i32;
172 }
173 }
174 -1
175 }
176 }
177 }
178
179 #[cfg(target_arch = "x86_64")]
180 #[inline]
181 fn find_positions_batch_simd_avx2(&self, values: u32x8) -> [i32; 8] {
182 unsafe {
183 use std::arch::x86_64::*;
184
185 if is_x86_feature_detected!("avx2") {
186 let values_array = values.to_array();
187 let input_vec = _mm256_loadu_si256(values_array.as_ptr() as *const __m256i);
188
189 let table_values = self.table.to_array();
190 let _table_vec = _mm256_loadu_si256(table_values.as_ptr() as *const __m256i);
192
193 let mut result = [-1i32; 8];
194
195 for table_pos in 0..self.count {
197 let table_broadcast = _mm256_set1_epi32(table_values[table_pos] as i32);
199
200 let cmp_result = _mm256_cmpeq_epi32(input_vec, table_broadcast);
202
203 let mask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_result));
205
206 for i in 0..8 {
207 if (mask & (1 << i)) != 0 && result[i] == -1 {
208 result[i] = table_pos as i32;
210 }
211 }
212 }
213
214 result
215 } else {
216 let values_array = values.to_array();
218 let table_array = self.table.to_array();
219 let mut result = [-1i32; 8];
220
221 for i in 0..8 {
222 for j in 0..self.count {
223 if values_array[i] == table_array[j] {
224 result[i] = j as i32;
225 break;
226 }
227 }
228 }
229
230 result
231 }
232 }
233 }
234
235 #[cfg(target_arch = "aarch64")]
236 #[inline]
237 fn find_position_simd_neon(&self, value: u32) -> i32 {
238 unsafe {
239 use std::arch::aarch64::*;
240
241 let table_values = self.table.to_array();
243 let table_vec1 = vld1q_u32(table_values.as_ptr());
244 let table_vec2 = vld1q_u32(table_values.as_ptr().add(4));
245
246 let input_vec = vdupq_n_u32(value);
248
249 let cmp1 = vceqq_u32(input_vec, table_vec1);
251 let cmp2 = vceqq_u32(input_vec, table_vec2);
252
253 let cmp1_array: [u32; 4] = std::mem::transmute(cmp1);
255 let cmp2_array: [u32; 4] = std::mem::transmute(cmp2);
256
257 for i in 0..4.min(self.count) {
259 if cmp1_array[i] != 0 {
260 return i as i32;
261 }
262 }
263
264 if self.count > 4 {
266 for i in 0..(self.count - 4) {
267 if cmp2_array[i] != 0 {
268 return (i + 4) as i32;
269 }
270 }
271 }
272
273 -1
274 }
275 }
276
277 #[cfg(target_arch = "aarch64")]
278 #[inline]
279 fn find_positions_batch_simd_neon(&self, values: u32x8) -> [i32; 8] {
280 unsafe {
281 use std::arch::aarch64::*;
282
283 let values_array = values.to_array();
284 let input_vec1 = vld1q_u32(values_array.as_ptr());
285 let input_vec2 = vld1q_u32(values_array.as_ptr().add(4));
286
287 let table_values = self.table.to_array();
288
289 let mut result = [-1i32; 8];
290
291 for table_pos in 0..self.count {
293 let table_broadcast = vdupq_n_u32(table_values[table_pos]);
294
295 let cmp1 = vceqq_u32(input_vec1, table_broadcast);
296 let cmp2 = vceqq_u32(input_vec2, table_broadcast);
297
298 let cmp1_array: [u32; 4] = std::mem::transmute(cmp1);
299 let cmp2_array: [u32; 4] = std::mem::transmute(cmp2);
300
301 for i in 0..4 {
303 if cmp1_array[i] != 0 && result[i] == -1 {
304 result[i] = table_pos as i32;
305 }
306 }
307
308 for i in 0..4 {
310 if cmp2_array[i] != 0 && result[i + 4] == -1 {
311 result[i + 4] = table_pos as i32;
312 }
313 }
314 }
315
316 result
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Stored values resolve to their slot; absent values yield -1.
    #[test]
    fn test_basic_position_lookup() {
        let table = EightValueLookup::new(&[10, 20, 30, 40]);

        let cases: [(u32, i32); 6] = [(10, 0), (20, 1), (30, 2), (40, 3), (5, -1), (50, -1)];
        for &(needle, expected) in cases.iter() {
            assert_eq!(table.find_position(needle), expected);
        }
    }

    /// A table using all eight slots finds every entry and nothing else.
    #[test]
    fn test_full_table() {
        let table = EightValueLookup::new(&[1, 2, 3, 4, 5, 6, 7, 8]);

        for (slot, needle) in (1u32..=8).enumerate() {
            assert_eq!(table.find_position(needle), slot as i32);
        }

        for &absent in [0u32, 9].iter() {
            assert_eq!(table.find_position(absent), -1);
        }
    }

    /// An empty table never matches and reports a length of zero.
    #[test]
    fn test_empty_table() {
        let table = EightValueLookup::new(&[]);

        for &needle in [0u32, 1].iter() {
            assert_eq!(table.find_position(needle), -1);
        }
        assert!(table.is_empty());
        assert_eq!(table.len(), 0);
    }

    /// A one-entry table matches only that entry.
    #[test]
    fn test_single_value() {
        let table = EightValueLookup::new(&[42]);

        let cases: [(u32, i32); 3] = [(42, 0), (41, -1), (43, -1)];
        for &(needle, expected) in cases.iter() {
            assert_eq!(table.find_position(needle), expected);
        }
        assert_eq!(table.len(), 1);
    }

    /// Batch lookup reports a position or -1 for each input lane.
    #[test]
    fn test_batch_position_lookup() {
        let table = EightValueLookup::new(&[10, 20, 30, 40, 50]);
        let needles = u32x8::from([10, 15, 20, 25, 30, 35, 40, 45]);

        let found = table.find_positions_batch(needles);

        assert_eq!(found, [0, -1, 1, -1, 2, -1, 3, -1]);
    }

    /// When a value is stored more than once, the lowest index is reported.
    #[test]
    fn test_duplicates_return_first_position() {
        let table = EightValueLookup::new(&[10, 20, 10, 30, 20]);

        let cases: [(u32, i32); 3] = [(10, 0), (20, 1), (30, 3)];
        for &(needle, expected) in cases.iter() {
            assert_eq!(table.find_position(needle), expected);
        }
    }

    /// Values near `u32::MAX` are matched correctly.
    #[test]
    fn test_large_values() {
        let stored = [
            u32::MAX - 7,
            u32::MAX - 5,
            u32::MAX - 3,
            u32::MAX - 1,
            u32::MAX,
        ];
        let table = EightValueLookup::new(&stored);

        // Checked from the last slot down to the first.
        for (slot, &needle) in stored.iter().enumerate().rev() {
            assert_eq!(table.find_position(needle), slot as i32);
        }

        for &absent in [u32::MAX - 2, u32::MAX - 4].iter() {
            assert_eq!(table.find_position(absent), -1);
        }
    }

    /// Batch and single lookups agree lane by lane.
    #[test]
    fn test_batch_vs_single_consistency() {
        let table = EightValueLookup::new(&[5, 15, 25, 35, 45, 55, 65, 75]);
        let needles = u32x8::from([5, 10, 15, 20, 25, 30, 35, 40]);

        let batch = table.find_positions_batch(needles);

        for (lane, (&needle, &from_batch)) in
            needles.to_array().iter().zip(batch.iter()).enumerate()
        {
            let from_single = table.find_position(needle);
            assert_eq!(
                from_batch, from_single,
                "Mismatch for value {} at index {}: batch={}, single={}",
                needle, lane, from_batch, from_single
            );
        }
    }

    /// Supplying nine values trips the constructor's capacity assertion.
    #[test]
    #[should_panic(expected = "EightValueLookup supports at most 8 values")]
    fn test_too_many_values() {
        let oversized: [u32; 9] = [1, 2, 3, 4, 5, 6, 7, 8, 9];
        EightValueLookup::new(&oversized);
    }

    /// `as_array` exposes the stored values followed by zero padding.
    #[test]
    fn test_as_array() {
        let table = EightValueLookup::new(&[10, 20, 30]);

        assert_eq!(table.as_array(), [10, 20, 30, 0, 0, 0, 0, 0]);
    }
}