simd_lookup/
small_table.rs

1//! SIMD enabled efficient small table lookups - for 64 entries or 64K entries.
2//! May be 2-D lookups as well.
3//!
4//! # CPU Feature Requirements
5//!
6//! ## Table64 (64-entry lookup table)
7//!
8//! **`Table64` is primarily optimized for ARM NEON** and provides excellent performance on Apple Silicon
9//! and other ARMv8+ CPUs. It also supports Intel AVX-512 on newer CPUs.
10//!
11//! ### ARM aarch64 (Primary Optimization Target)
12//! - **Optimal**: Uses ARM NEON `TBL4` instruction (`vqtbl4q_u8`)
13//!   - Native hardware support on all ARMv8+ CPUs (including Apple M1/M2/M3)
14//!   - Extremely efficient single-instruction 64-byte table lookup
15//!   - No fallback needed - full SIMD acceleration on ARM
16//!   - The `TBL4` instruction can perform 64-entry lookups in a single operation
17//!
18//! ### Intel x86_64
19//! - **Optimal**: Requires **AVX512BW** + **AVX512VBMI**
20//!   - Uses `VPERMB` instruction (`_mm512_permutexvar_epi8`) for 64-byte table lookups
21//!   - Available on: Intel Ice Lake, Tiger Lake, and later (not available on Skylake-X)
22//!   - Fallback: Scalar lookup (works on all x86_64 CPUs)
23//!
24//! ## Table2dU8xU8 (2D lookup table, up to 64K entries)
25//!
26//! ### Intel x86_64
27//! - **Optimal**: Requires **AVX512F** + **AVX512BW** (via `simd_gather` module)
28//!   - Uses `VGATHERDPS` + `VPMOVDB` for parallel lookups
29//!   - Available on: Intel Skylake-X (Xeon), Ice Lake, Tiger Lake, and later
30//!   - Fallback: Scalar lookup (works on all architectures)
31//!
32//! ### ARM aarch64
33//! - Uses scalar fallback (NEON gather is not significantly faster than scalar for this use case)
34
35use crate::simd_gather::gather_u32index_u8;
36use crate::wide_utils::WideUtilsExt;
37use std::fmt;
38use wide::{u8x16, u16x16, u32x16};
39
40#[cfg(target_arch = "aarch64")]
41use core::arch::aarch64::{uint8x16x4_t, vld1q_u8, vqtbl4q_u8, vst1q_u8};
42
43#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
44use core::arch::x86_64::{
45    __m128i, __m512i, _mm_loadu_si128, _mm_storeu_si128, _mm512_castsi128_si512,
46    _mm512_castsi512_si128, _mm512_loadu_si512, _mm512_permutexvar_epi8, _mm512_storeu_si512,
47};
48
49#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
50use std::arch::is_x86_feature_detected as det;
51
52//------------------- SIMD small table lookup functions (ARM NEON VTBL etc.) ---------------------------------------
53// The idea is optimized small table (say <=64 entries) lookup, which can be done in only a few instructions.
54// Or, you can think of it as an 8x8 lookup table.
55
56/// A SIMD-optimized 64-entry lookup table, able to do extremely efficient lookups in ARM NEON and Intel AVX-512VBMI.
57///
58/// # 2D Interpretation
59///
60/// `Table64` can also be viewed as an 8×8 two-dimensional table stored in row-major order:
61///
62/// ```text
63///        col 0  col 1  col 2  col 3  col 4  col 5  col 6  col 7
64/// row 0:   0      1      2      3      4      5      6      7
65/// row 1:   8      9     10     11     12     13     14     15
66/// row 2:  16     17     18     19     20     21     22     23
67/// row 3:  24     25     26     27     28     29     30     31
68/// row 4:  32     33     34     35     36     37     38     39
69/// row 5:  40     41     42     43     44     45     46     47
70/// row 6:  48     49     50     51     52     53     54     55
71/// row 7:  56     57     58     59     60     61     62     63
72/// ```
73///
74/// Use [`lookup_one_2d`](Self::lookup_one_2d) to perform lookups using (row, column) coordinates.
pub struct Table64 {
    // ARM: the 64-byte table preloaded into four NEON q-registers, in the
    // exact layout consumed by `vqtbl4q_u8`.
    #[cfg(target_arch = "aarch64")]
    neon_tbl: uint8x16x4_t,

    // x86: raw copy of the table, used by the scalar fallback path and to
    // reconstruct the bytes in `as_bytes`.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    bytes: [u8; 64],

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    zmm: Option<__m512i>, // preloaded 64B table for AVX-512VBMI
}
85
impl Table64 {
    /// Builds a `Table64` from a 64-byte table, preloading it into the
    /// platform-preferred form: four NEON registers on aarch64, or an
    /// optional ZMM copy on x86 when AVX-512VBMI is detected at runtime.
    #[inline]
    pub fn new(table: &[u8; 64]) -> Self {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            // Preload into a ZMM register only when VPERMB is usable;
            // otherwise `zmm` stays `None` and lookups take the scalar path.
            let zmm = if is_x86_avx512_vbmi() {
                unsafe {
                    // SAFETY: `table` is exactly 64 bytes and
                    // `_mm512_loadu_si512` permits unaligned loads; the
                    // runtime feature check above gates the instruction.
                    let z = _mm512_loadu_si512(table.as_ptr() as *const _);
                    Some(z)
                }
            } else {
                None
            };

            Self { bytes: *table, zmm }
        }

        #[cfg(target_arch = "aarch64")]
        {
            Self {
                neon_tbl: unsafe {
                    // SAFETY: the four 16-byte loads at offsets 0/16/32/48
                    // all lie within the 64-byte `table`.
                    let t0 = vld1q_u8(table.as_ptr());
                    let t1 = vld1q_u8(table.as_ptr().add(16));
                    let t2 = vld1q_u8(table.as_ptr().add(32));
                    let t3 = vld1q_u8(table.as_ptr().add(48));
                    uint8x16x4_t(t0, t1, t2, t3)
                },
            }
        }
    }

    /// Single-vector lookup: each byte of `idx` (0..63) selects from this 64B table.
    /// Returns a `u8x16` with the looked-up values.
    ///
    /// NOTE(review): lanes >= 64 are unsupported and appear to diverge per
    /// backend (NEON `TBL` yields 0 for out-of-range indices, `VPERMB` uses
    /// the index modulo 64, and the scalar fallback panics on the
    /// out-of-bounds slice index) — callers must keep indices in range.
    #[inline]
    pub fn lookup_one(&self, idx: u8x16) -> u8x16 {
        #[cfg(target_arch = "aarch64")]
        unsafe {
            // SAFETY: `idx.as_array()` and `out` are both 16 bytes, matching
            // the 128-bit load/store widths used here.
            let i = vld1q_u8(idx.as_array().as_ptr());
            let r = vqtbl4q_u8(self.neon_tbl, i);
            let mut out = [0u8; 16];
            vst1q_u8(out.as_mut_ptr(), r);
            u8x16::from(out)
        }

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if let Some(tzmm) = self.zmm {
                // SAFETY: only 16 bytes are loaded/stored (the size of the
                // `u8x16` argument/result); `zmm` is Some only when the
                // AVX-512BW+VBMI runtime check passed in `new`.
                unsafe {
                    // Load only 16 bytes (safe) into XMM register
                    let iv_128 = _mm_loadu_si128(idx.as_array().as_ptr() as *const __m128i);
                    // Zero-cost cast to ZMM (upper bytes undefined, but we don't use them)
                    let iv = _mm512_castsi128_si512(iv_128);
                    // VPERMB: only first 16 result bytes are valid
                    let rv = _mm512_permutexvar_epi8(iv, tzmm);
                    // Extract low 128 bits (zero latency - register rename)
                    let rv_128 = _mm512_castsi512_si128(rv);
                    // Store only 16 bytes
                    let mut result = [0u8; 16];
                    _mm_storeu_si128(result.as_mut_ptr() as *mut __m128i, rv_128);
                    u8x16::from(result)
                }
            } else {
                scalar_lookup_1x16(&self.bytes, idx)
            }
        }

        #[cfg(not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")))]
        compile_error!(
            "Table64::lookup_one is implemented for aarch64 (NEON) and x86/x86_64 (AVX-512VBMI)."
        );
    }

    /// 2D lookup: treats the 64-entry table as an 8×8 row-major matrix.
    ///
    /// Each lane computes `index = row * 8 + col` and looks up the corresponding value.
    ///
    /// # Arguments
    /// - `rows`: Row indices (0..7) for each of the 16 lanes
    /// - `cols`: Column indices (0..7) for each of the 16 lanes
    ///
    /// # Panics (debug only)
    /// Debug-asserts that all row and column values are in range 0..8.
    ///
    /// # Example
    /// ```ignore
    /// let table = Table64::new(&data);
    /// let rows = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]);
    /// let cols = u8x16::from([0, 0, 0, 0, 0, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7]);
    /// let result = table.lookup_one_2d(rows, cols);
    /// // Looks up indices [0, 8, 16, 24, 32, 40, 48, 56, 7, 15, 23, 31, 39, 47, 55, 63]
    /// ```
    #[inline]
    pub fn lookup_one_2d(&self, rows: u8x16, cols: u8x16) -> u8x16 {
        debug_assert!(
            rows.to_array().iter().all(|&r| r < 8),
            "All row indices must be < 8"
        );
        debug_assert!(
            cols.to_array().iter().all(|&c| c < 8),
            "All column indices must be < 8"
        );

        // index = row * 8 + col
        // Use double().double().double() for efficient ×8 via SIMD addition
        // x86-64 does not have SIMD support for u8 multiply unfortunately
        let idx = rows.double().double().double() + cols;
        self.lookup_one(idx)
    }

    /// Get the underlying bytes array (for debugging/display purposes).
    /// This extracts the data from platform-specific storage.
    #[inline]
    fn as_bytes(&self) -> [u8; 64] {
        #[cfg(target_arch = "aarch64")]
        {
            unsafe {
                // SAFETY: `bytes` is 64 bytes; the four 16-byte stores at
                // offsets 0/16/32/48 exactly reverse the loads in `new`.
                let mut bytes = [0u8; 64];
                vst1q_u8(bytes.as_mut_ptr(), self.neon_tbl.0);
                vst1q_u8(bytes.as_mut_ptr().add(16), self.neon_tbl.1);
                vst1q_u8(bytes.as_mut_ptr().add(32), self.neon_tbl.2);
                vst1q_u8(bytes.as_mut_ptr().add(48), self.neon_tbl.3);
                bytes
            }
        }

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            // x86 keeps a plain copy around, so no extraction is needed.
            self.bytes
        }
    }

    /// Dynamic lookup: each byte of `idx[k]` (0..63) selects from this 64B table.
    /// - Requires: `idx.len() == out.len()`
    /// - No element tails (I/O is in whole `u8x16` blocks).
    #[inline]
    pub fn lookup(&self, idx: &[u8x16], out: &mut [u8x16]) {
        assert_eq!(idx.len(), out.len());

        #[cfg(target_arch = "aarch64")]
        unsafe {
            // Treat &[u8x16] as a flat &[u8] for direct loads/stores.
            // SAFETY: `u8x16` is 16 bytes, so `idx`/`out` provide
            // `len() * 16` readable/writable bytes; each iteration touches
            // exactly the 16 bytes at offset `b * 16`.
            let idx_bytes = idx.as_ptr() as *const u8;
            let out_bytes = out.as_mut_ptr() as *mut u8;

            for b in 0..idx.len() {
                let i_ptr = idx_bytes.add(b * 16);
                let o_ptr = out_bytes.add(b * 16);

                let i = vld1q_u8(i_ptr);
                let r = vqtbl4q_u8(self.neon_tbl, i); // 64-entry dynamic table
                vst1q_u8(o_ptr, r);
            }
        }

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        unsafe {
            let mut i = 0usize;
            if let Some(tzmm) = self.zmm {
                // Process 4×u8x16 at a time (64 bytes) with one vpermb.
                // SAFETY: the loop condition `i + 4 <= idx.len()` guarantees
                // the 64-byte load/store at offset `i * 16` stays in bounds
                // of both slices (which have equal length per the assert).
                let idx_bytes = idx.as_ptr() as *const u8;
                let out_bytes = out.as_mut_ptr() as *mut u8;

                while i + 4 <= idx.len() {
                    let off = i * 16;
                    let iv = _mm512_loadu_si512(idx_bytes.add(off) as *const __m512i);
                    let rv = _mm512_permutexvar_epi8(iv, tzmm);
                    _mm512_storeu_si512(out_bytes.add(off) as *mut __m512i, rv);
                    i += 4;
                }
            }

            // Handle remainder blocks — scalar per 16B block; still no per-byte tails.
            // (Also the whole path when AVX-512VBMI is unavailable, i.e. `zmm` is None.)
            for k in i..idx.len() {
                out[k] = scalar_lookup_1x16(&self.bytes, idx[k]);
            }
        }

        #[cfg(not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")))]
        compile_error!(
            "Table64::lookup is implemented for aarch64 (NEON) and x86/x86_64 (AVX-512VBMI)."
        );
    }
}
269
270impl Clone for Table64 {
271    fn clone(&self) -> Self {
272        let bytes = self.as_bytes();
273        Self::new(&bytes)
274    }
275}
276
277impl Default for Table64 {
278    fn default() -> Self {
279        Self::new(&[0u8; 64])
280    }
281}
282
283impl fmt::Debug for Table64 {
284    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285        let bytes = self.as_bytes();
286        writeln!(f, "Table64 {{")?;
287        writeln!(f, "        col 0  col 1  col 2  col 3  col 4  col 5  col 6  col 7")?;
288        for row in 0..8 {
289            write!(f, "row {}: ", row)?;
290            for col in 0..8 {
291                let idx = row * 8 + col;
292                write!(f, "{:5} ", bytes[idx])?;
293            }
294            writeln!(f)?;
295        }
296        write!(f, "}}")
297    }
298}
299
300// =============================================================================
301// Table2dU8xU8 - 2D lookup table with up to 64K entries (256×256)
302// =============================================================================
303
304/// A 2D SIMD lookup table for `u8 × u8` coordinates, supporting up to 64K entries.
305///
306/// This table stores data in row-major order and uses SIMD gather operations for
307/// efficient parallel lookups. Each lookup takes a row index (0..num_rows) and
308/// column index (0..num_cols), both as u8, and returns the corresponding value.
309///
310/// # Index Calculation
311///
312/// For row `r` and column `c`, the flat index is: `index = r * num_cols + c`
313///
314/// Since row and column are both u8 (max 255), and num_cols is at most 256,
315/// the maximum index is 255 * 256 + 255 = 65535, which fits in u16.
316///
317/// # Example
318///
319/// ```ignore
320/// // Create a 16x16 multiplication table
321/// let mut data = vec![0u8; 256];
322/// for r in 0..16u8 {
323///     for c in 0..16u8 {
324///         data[(r as usize) * 16 + (c as usize)] = r.wrapping_mul(c);
325///     }
326/// }
327/// let table = Table2dU8xU8::from_flat(&data, 16);
328///
329/// // Look up multiple (row, col) pairs in parallel
330/// let rows = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
331/// let cols = u8x16::splat(5);  // All looking up column 5
332/// let result = table.lookup_one(rows, cols);
333/// // result[i] = i * 5
334/// ```
#[derive(Clone, Default)]
pub struct Table2dU8xU8 {
    // Row-major storage: entry (r, c) lives at `data[r * num_cols + c]`.
    data: Vec<u8>,
    // Columns per row, 1..=256. Fits in u16, and so does the maximum flat
    // index (255 * 256 + 255 = 65535), which lookup arithmetic relies on.
    num_cols: u16,
}
340
341impl Table2dU8xU8 {
342    /// Create a 2D table from a flat slice with the given number of columns.
343    ///
344    /// The data is stored in row-major order: `data[row * num_cols + col]`.
345    ///
346    /// # Arguments
347    /// - `data`: Flat slice of values, length must be `num_rows * num_cols`
348    /// - `num_cols`: Number of columns per row (1..=256)
349    ///
350    /// # Panics
351    /// - Panics if `num_cols` is 0 or greater than 256
352    /// - Panics if `data.len()` is not a multiple of `num_cols`
353    /// - Panics if `data.len() > 65536`
354    #[inline]
355    pub fn from_flat(data: &[u8], num_cols: usize) -> Self {
356        assert!(num_cols > 0 && num_cols <= 256, "num_cols must be 1..=256");
357        assert!(data.len() % num_cols == 0, "data length must be multiple of num_cols");
358        assert!(data.len() <= 65536, "data length must be <= 65536 (64K entries)");
359
360        Self {
361            data: data.to_vec(),
362            num_cols: num_cols as u16,
363        }
364    }
365
366    /// Create a 2D table from a 2D matrix (Vec of rows).
367    ///
368    /// All rows must have the same length.
369    ///
370    /// # Panics
371    /// - Panics if the matrix is empty
372    /// - Panics if rows have different lengths
373    /// - Panics if total size exceeds 65536
374    #[inline]
375    pub fn from_2d(matrix: &[&[u8]]) -> Self {
376        assert!(!matrix.is_empty(), "matrix cannot be empty");
377        let num_cols = matrix[0].len();
378        assert!(num_cols > 0 && num_cols <= 256, "num_cols must be 1..=256");
379        assert!(matrix.iter().all(|row| row.len() == num_cols), "all rows must have same length");
380        assert!(matrix.len() * num_cols <= 65536, "total size must be <= 65536");
381
382        let mut data = Vec::with_capacity(matrix.len() * num_cols);
383        for row in matrix {
384            data.extend_from_slice(row);
385        }
386
387        Self {
388            data,
389            num_cols: num_cols as u16,
390        }
391    }
392
393    /// Returns the number of columns per row.
394    #[inline]
395    pub fn num_cols(&self) -> usize {
396        self.num_cols as usize
397    }
398
399    /// Returns the number of rows in the table.
400    #[inline]
401    pub fn num_rows(&self) -> usize {
402        self.data.len() / self.num_cols as usize
403    }
404
405    /// Returns the total number of entries in the table.
406    #[inline]
407    pub fn len(&self) -> usize {
408        self.data.len()
409    }
410
411    /// Returns true if the table is empty.
412    #[inline]
413    pub fn is_empty(&self) -> bool {
414        self.data.is_empty()
415    }
416
417    /// Look up 16 values in parallel using (row, col) coordinates.
418    ///
419    /// Computes `result[i] = table[rows[i]][cols[i]]` for all 16 lanes.
420    ///
421    /// # Arguments
422    /// - `rows`: Row indices (0..num_rows) for each of the 16 lanes
423    /// - `cols`: Column indices (0..num_cols) for each of the 16 lanes
424    ///
425    /// # Safety
426    /// In debug mode, asserts that all indices are in bounds.
427    /// In release mode, out-of-bounds access is undefined behavior.
428    #[inline]
429    pub fn lookup_one(&self, rows: u8x16, cols: u8x16) -> u8x16 {
430        // Widen u8x16 → u16x16 for arithmetic
431        let rows_u16: u16x16 = u16x16::from(rows);
432        let cols_u16: u16x16 = u16x16::from(cols);
433        let num_cols_u16 = u16x16::splat(self.num_cols);
434
435        // index = row * num_cols + col (all in u16x16)
436        let indices_u16 = rows_u16 * num_cols_u16 + cols_u16;
437
438        // Widen u16x16 → u32x16 for gather
439        let indices_u32: u32x16 = u32x16::from(indices_u16);
440
441        // Debug bounds check
442        #[cfg(debug_assertions)]
443        {
444            let idx_arr = indices_u32.to_array();
445            for (i, &idx) in idx_arr.iter().enumerate() {
446                debug_assert!(
447                    (idx as usize) < self.data.len(),
448                    "Index out of bounds at lane {}: {} >= {}",
449                    i, idx, self.data.len()
450                );
451            }
452        }
453
454        // Use SIMD gather (AVX-512 on x86, scalar fallback elsewhere)
455        gather_u32index_u8(indices_u32, &self.data, 1)
456    }
457
458    /// Scalar lookup for a single (row, col) coordinate.
459    ///
460    /// # Panics
461    /// Panics if row or col is out of bounds.
462    #[inline]
463    pub fn get(&self, row: u8, col: u8) -> u8 {
464        let index = (row as usize) * (self.num_cols as usize) + (col as usize);
465        self.data[index]
466    }
467}
468
469impl fmt::Debug for Table2dU8xU8 {
470    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471        let num_rows = self.num_rows();
472        let num_cols = self.num_cols as usize;
473
474        writeln!(f, "Table2dU8xU8 {{")?;
475        writeln!(f, "  dimensions: {} rows × {} cols", num_rows, num_cols)?;
476
477        if self.data.is_empty() {
478            return write!(f, "  (empty)}}");
479        }
480
481        // Limit display to reasonable size: max 20 rows and 20 cols
482        const MAX_DISPLAY_ROWS: usize = 20;
483        const MAX_DISPLAY_COLS: usize = 20;
484
485        let display_rows = num_rows.min(MAX_DISPLAY_ROWS);
486        let display_cols = num_cols.min(MAX_DISPLAY_COLS);
487        let show_row_ellipsis = num_rows > MAX_DISPLAY_ROWS;
488        let show_col_ellipsis = num_cols > MAX_DISPLAY_COLS;
489
490        // Print column headers
491        write!(f, "  ")?;
492        for col in 0..display_cols {
493            write!(f, " col{:3}", col)?;
494        }
495        if show_col_ellipsis {
496            write!(f, " ...")?;
497        }
498        writeln!(f)?;
499
500        // Print rows
501        for row in 0..display_rows {
502            write!(f, "  row{:3}:", row)?;
503            for col in 0..display_cols {
504                let idx = row * num_cols + col;
505                write!(f, "{:5}", self.data[idx])?;
506            }
507            if show_col_ellipsis {
508                write!(f, " ...")?;
509            }
510            writeln!(f)?;
511        }
512
513        if show_row_ellipsis {
514            writeln!(f, "  ...")?;
515        }
516
517        write!(f, "}}")
518    }
519}
520
521// ------------------
522// Helpers
523// ------------------
524
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline]
fn is_x86_avx512_vbmi() -> bool {
    // VPERMB needs AVX-512VBMI; the byte-granular ops also require AVX-512BW.
    let has_bw = det!("avx512bw");
    let has_vbmi = det!("avx512vbmi");
    has_bw && has_vbmi
}
530
531/// Scalar per-vector fallback: takes/returns `u8x16`; no element tails.
532/// Preconditions: every lane < 64.
533#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
534#[inline]
535fn scalar_lookup_1x16(table: &[u8; 64], idx: u8x16) -> u8x16 {
536    let i = idx.to_array();
537    debug_assert!(i.iter().all(|&x| x < 64));
538    let out = [
539        table[i[0] as usize],
540        table[i[1] as usize],
541        table[i[2] as usize],
542        table[i[3] as usize],
543        table[i[4] as usize],
544        table[i[5] as usize],
545        table[i[6] as usize],
546        table[i[7] as usize],
547        table[i[8] as usize],
548        table[i[9] as usize],
549        table[i[10] as usize],
550        table[i[11] as usize],
551        table[i[12] as usize],
552        table[i[13] as usize],
553        table[i[14] as usize],
554        table[i[15] as usize],
555    ];
556    u8x16::from(out)
557}
558
559#[cfg(test)]
560mod tests {
561    use super::*;
562
    /// Builds a 64-entry table with the affine pattern `3*i + 7`, so every
    /// slot holds a distinct, easily recomputed value.
    fn create_test_table() -> [u8; 64] {
        let mut table = [0u8; 64];
        for i in 0..64 {
            table[i] = (i * 3 + 7) as u8; // Pattern: 7, 10, 13, 16, ...
        }
        table
    }
570
    /// Smoke test: construction (and Debug formatting) must not panic.
    #[test]
    fn test_table64_new() {
        let table_data = create_test_table();
        let table = Table64::new(&table_data);
        println!("\n{:?}", table);
        // Just ensure construction doesn't panic
    }
578
    /// Sequential indices 0..16 must return the first 16 table entries.
    #[test]
    fn test_lookup_one_basic() {
        let table_data = create_test_table();
        let table = Table64::new(&table_data);

        // Lookup indices 0-15
        let idx = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let result = table.lookup_one(idx);
        let result_arr = result.to_array();

        // Verify each lookup
        for i in 0..16 {
            assert_eq!(
                result_arr[i], table_data[i],
                "Mismatch at index {}: expected {}, got {}",
                i, table_data[i], result_arr[i]
            );
        }
    }
598
    /// Non-sequential indices spanning all four 16-byte quarters of the
    /// table must each return the matching entry.
    #[test]
    fn test_lookup_one_scattered_indices() {
        let table_data = create_test_table();
        let table = Table64::new(&table_data);

        // Scattered indices across the table
        let idx = u8x16::from([0, 63, 32, 16, 48, 1, 62, 31, 15, 47, 8, 56, 4, 60, 20, 40]);
        let result = table.lookup_one(idx);
        let result_arr = result.to_array();
        let idx_arr = idx.to_array();

        for i in 0..16 {
            assert_eq!(
                result_arr[i],
                table_data[idx_arr[i] as usize],
                "Mismatch at position {}: idx={}, expected {}, got {}",
                i,
                idx_arr[i],
                table_data[idx_arr[i] as usize],
                result_arr[i]
            );
        }
    }
622
    /// Splatting one index across all 16 lanes must broadcast a single value.
    #[test]
    fn test_lookup_one_all_same_index() {
        let table_data = create_test_table();
        let table = Table64::new(&table_data);

        // All indices are the same
        let idx = u8x16::splat(42);
        let result = table.lookup_one(idx);
        let result_arr = result.to_array();

        let expected = table_data[42];
        for i in 0..16 {
            assert_eq!(
                result_arr[i], expected,
                "All lookups should return the same value"
            );
        }
    }
641
    /// Batch `lookup` over 4 vectors covering indices 0..64 must reproduce
    /// the whole table in order.
    #[test]
    fn test_lookup_batch() {
        let table_data = create_test_table();
        let table = Table64::new(&table_data);

        let indices = vec![
            u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
            u8x16::from([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]),
            u8x16::from([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]),
            u8x16::from([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]),
        ];
        let mut output = vec![u8x16::splat(0); 4];

        table.lookup(&indices, &mut output);

        // Verify all lookups
        for (vec_idx, out_vec) in output.iter().enumerate() {
            let out_arr = out_vec.to_array();
            for lane in 0..16 {
                let table_idx = vec_idx * 16 + lane;
                assert_eq!(
                    out_arr[lane], table_data[table_idx],
                    "Mismatch at vec {}, lane {}: expected {}, got {}",
                    vec_idx, lane, table_data[table_idx], out_arr[lane]
                );
            }
        }
    }
670
    /// `lookup_one` and a 1-element `lookup` batch must agree, exercising
    /// both code paths (the batch remainder path handles a single vector).
    #[test]
    fn test_lookup_one_matches_lookup_batch() {
        let table_data = create_test_table();
        let table = Table64::new(&table_data);

        let idx = u8x16::from([5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 0, 32, 63, 1]);

        // Single lookup
        let single_result = table.lookup_one(idx);

        // Batch lookup with single element
        let mut batch_output = vec![u8x16::splat(0); 1];
        table.lookup(&[idx], &mut batch_output);

        assert_eq!(
            single_result.to_array(),
            batch_output[0].to_array(),
            "lookup_one and lookup should produce the same result"
        );
    }
691
    /// With an identity table (`table[i] == i`), lookups must echo the indices.
    #[test]
    fn test_identity_table() {
        // Create an identity table where table[i] = i
        let mut table_data = [0u8; 64];
        for i in 0..64 {
            table_data[i] = i as u8;
        }
        let table = Table64::new(&table_data);

        let idx = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let result = table.lookup_one(idx);

        assert_eq!(idx.to_array(), result.to_array(), "Identity table should return input indices");
    }
706
707    // ==================== 2D Lookup Tests ====================
708
    /// Create an 8x8 table where table[row][col] = row * 10 + col
    /// This makes it easy to verify 2D lookups: result should be row*10 + col
    /// (max value is 77, so no u8 wrap-around).
    fn create_2d_test_table() -> [u8; 64] {
        let mut table = [0u8; 64];
        for row in 0..8 {
            for col in 0..8 {
                table[row * 8 + col] = (row * 10 + col) as u8;
            }
        }
        table
    }
720
    /// 2D lookups over the first two full rows must return row*10 + col.
    #[test]
    fn test_lookup_one_2d_basic() {
        let table_data = create_2d_test_table();
        let table = Table64::new(&table_data);

        // Lookup first row (row=0, cols=0..7) and second row (row=1, cols=0..7)
        let rows = u8x16::from([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]);
        let cols = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]);

        let result = table.lookup_one_2d(rows, cols);
        let result_arr = result.to_array();

        // First 8: row 0, should be 0, 1, 2, 3, 4, 5, 6, 7
        for col in 0..8 {
            assert_eq!(result_arr[col], col as u8, "Row 0, col {}", col);
        }
        // Next 8: row 1, should be 10, 11, 12, 13, 14, 15, 16, 17
        for col in 0..8 {
            assert_eq!(result_arr[8 + col], (10 + col) as u8, "Row 1, col {}", col);
        }
    }
742
    /// Main diagonal and anti-diagonal coordinates must map to the expected
    /// row*10 + col values.
    #[test]
    fn test_lookup_one_2d_diagonal() {
        let table_data = create_2d_test_table();
        let table = Table64::new(&table_data);

        // Diagonal: (0,0), (1,1), (2,2), ..., (7,7), then reverse diagonal
        let rows = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0]);
        let cols = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]);

        let result = table.lookup_one_2d(rows, cols);
        let result_arr = result.to_array();

        // Main diagonal: row*10 + col where row == col
        for i in 0..8 {
            let expected = (i * 10 + i) as u8; // 0, 11, 22, 33, 44, 55, 66, 77
            assert_eq!(result_arr[i], expected, "Main diagonal position {}", i);
        }

        // Anti-diagonal part: row=7-i, col=i
        let expected_anti = [70, 61, 52, 43, 34, 25, 16, 7u8];
        for i in 0..8 {
            assert_eq!(result_arr[8 + i], expected_anti[i], "Anti-diagonal position {}", i);
        }
    }
767
    /// The four corner coordinates of the 8×8 view must hit indices
    /// 0, 7, 56, 63 (values 0, 7, 70, 77).
    #[test]
    fn test_lookup_one_2d_corners() {
        let table_data = create_2d_test_table();
        let table = Table64::new(&table_data);

        // Test all four corners repeated
        let rows = u8x16::from([0, 0, 7, 7, 0, 0, 7, 7, 0, 0, 7, 7, 0, 0, 7, 7]);
        let cols = u8x16::from([0, 7, 0, 7, 0, 7, 0, 7, 0, 7, 0, 7, 0, 7, 0, 7]);

        let result = table.lookup_one_2d(rows, cols);
        let result_arr = result.to_array();

        // Expected: (0,0)=0, (0,7)=7, (7,0)=70, (7,7)=77
        let expected = [0u8, 7, 70, 77, 0, 7, 70, 77, 0, 7, 70, 77, 0, 7, 70, 77];
        assert_eq!(result_arr, expected, "Corner lookups");
    }
784
    /// A splatted row with varying columns must stay within that row.
    #[test]
    fn test_lookup_one_2d_same_row() {
        let table_data = create_2d_test_table();
        let table = Table64::new(&table_data);

        // All from row 5
        let rows = u8x16::splat(5);
        let cols = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0]);

        let result = table.lookup_one_2d(rows, cols);
        let result_arr = result.to_array();
        let cols_arr = cols.to_array();

        for i in 0..16 {
            let expected = (50 + cols_arr[i]) as u8;
            assert_eq!(result_arr[i], expected, "Row 5, col {}", cols_arr[i]);
        }
    }
803
    /// A splatted column with varying rows must walk down that column.
    #[test]
    fn test_lookup_one_2d_same_col() {
        let table_data = create_2d_test_table();
        let table = Table64::new(&table_data);

        // All from column 3
        let rows = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]);
        let cols = u8x16::splat(3);

        let result = table.lookup_one_2d(rows, cols);
        let result_arr = result.to_array();

        // Column 3: 3, 13, 23, 33, 43, 53, 63, 73
        for i in 0..8 {
            let expected = (i * 10 + 3) as u8;
            assert_eq!(result_arr[i], expected, "Row {}, col 3", i);
            assert_eq!(result_arr[8 + i], expected, "Row {}, col 3 (second half)", i);
        }
    }
823
    /// `lookup_one_2d(rows, cols)` must equal `lookup_one(rows * 8 + cols)`,
    /// pinning the 2D wrapper to the flat-index implementation.
    #[test]
    fn test_lookup_one_2d_matches_lookup_one() {
        let table_data = create_2d_test_table();
        let table = Table64::new(&table_data);

        // Random (row, col) pairs
        let rows = u8x16::from([0, 3, 7, 2, 5, 1, 6, 4, 7, 0, 3, 5, 2, 6, 1, 4]);
        let cols = u8x16::from([5, 2, 0, 7, 3, 6, 1, 4, 7, 0, 4, 2, 6, 3, 5, 1]);

        // Compute expected indices manually
        let rows_arr = rows.to_array();
        let cols_arr = cols.to_array();
        let mut expected_idx = [0u8; 16];
        for i in 0..16 {
            expected_idx[i] = rows_arr[i] * 8 + cols_arr[i];
        }

        let result_2d = table.lookup_one_2d(rows, cols);
        let result_1d = table.lookup_one(u8x16::from(expected_idx));

        assert_eq!(
            result_2d.to_array(),
            result_1d.to_array(),
            "lookup_one_2d should match lookup_one with computed indices"
        );
    }
850
851    // ==================== Table2dU8xU8 Tests ====================
852
853    /// Create a test table where value = row * 10 + col
854    fn create_table2d_test_data(num_rows: usize, num_cols: usize) -> Vec<u8> {
855        let mut data = Vec::with_capacity(num_rows * num_cols);
856        for r in 0..num_rows {
857            for c in 0..num_cols {
858                data.push(((r * 10 + c) % 256) as u8);
859            }
860        }
861        data
862    }
863
864    #[test]
865    fn test_table2d_from_flat_basic() {
866        let data = create_table2d_test_data(16, 16);
867        let table = Table2dU8xU8::from_flat(&data, 16);
868
869        println!("\n{:?}", table);
870        assert_eq!(table.num_rows(), 16);
871        assert_eq!(table.num_cols(), 16);
872        assert_eq!(table.len(), 256);
873    }
874
875    #[test]
876    fn test_table2d_from_2d() {
877        let row0: &[u8] = &[0, 1, 2, 3];
878        let row1: &[u8] = &[10, 11, 12, 13];
879        let row2: &[u8] = &[20, 21, 22, 23];
880        let matrix: &[&[u8]] = &[row0, row1, row2];
881
882        let table = Table2dU8xU8::from_2d(matrix);
883
884        assert_eq!(table.num_rows(), 3);
885        assert_eq!(table.num_cols(), 4);
886        assert_eq!(table.len(), 12);
887
888        // Verify scalar lookup
889        assert_eq!(table.get(0, 0), 0);
890        assert_eq!(table.get(0, 3), 3);
891        assert_eq!(table.get(1, 0), 10);
892        assert_eq!(table.get(2, 3), 23);
893    }
894
895    #[test]
896    fn test_table2d_lookup_one_basic() {
897        let data = create_table2d_test_data(16, 16);
898        let table = Table2dU8xU8::from_flat(&data, 16);
899
900        // Look up row 0, cols 0..15
901        let rows = u8x16::splat(0);
902        let cols = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
903
904        let result = table.lookup_one(rows, cols);
905        let result_arr = result.to_array();
906
907        // Row 0: values are 0, 1, 2, ..., 15
908        for i in 0..16 {
909            assert_eq!(result_arr[i], i as u8, "Row 0, col {}", i);
910        }
911    }
912
913    #[test]
914    fn test_table2d_lookup_one_different_rows() {
915        let data = create_table2d_test_data(16, 16);
916        let table = Table2dU8xU8::from_flat(&data, 16);
917
918        // Look up different rows, same column
919        let rows = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
920        let cols = u8x16::splat(5);
921
922        let result = table.lookup_one(rows, cols);
923        let result_arr = result.to_array();
924
925        // Column 5: values are 5, 15, 25, 35, ... (row * 10 + 5)
926        for i in 0..16 {
927            let expected = ((i * 10 + 5) % 256) as u8;
928            assert_eq!(result_arr[i], expected, "Row {}, col 5", i);
929        }
930    }
931
932    #[test]
933    fn test_table2d_lookup_one_scattered() {
934        let data = create_table2d_test_data(16, 16);
935        let table = Table2dU8xU8::from_flat(&data, 16);
936
937        // Scattered lookups
938        let rows = u8x16::from([0, 5, 10, 15, 3, 8, 12, 1, 7, 14, 2, 9, 4, 11, 6, 13]);
939        let cols = u8x16::from([0, 15, 5, 10, 3, 8, 12, 1, 7, 14, 2, 9, 4, 11, 6, 13]);
940
941        let result = table.lookup_one(rows, cols);
942        let result_arr = result.to_array();
943        let rows_arr = rows.to_array();
944        let cols_arr = cols.to_array();
945
946        for i in 0..16 {
947            let expected = ((rows_arr[i] as usize * 10 + cols_arr[i] as usize) % 256) as u8;
948            assert_eq!(
949                result_arr[i], expected,
950                "Mismatch at lane {}: row={}, col={}, expected={}, got={}",
951                i, rows_arr[i], cols_arr[i], expected, result_arr[i]
952            );
953        }
954    }
955
956    #[test]
957    fn test_table2d_lookup_matches_scalar() {
958        let data = create_table2d_test_data(32, 20);
959        let table = Table2dU8xU8::from_flat(&data, 20);
960
961        let rows = u8x16::from([0, 5, 10, 15, 20, 25, 30, 31, 1, 6, 11, 16, 21, 26, 28, 29]);
962        let cols = u8x16::from([0, 5, 10, 15, 19, 0, 5, 10, 1, 6, 11, 16, 18, 1, 6, 11]);
963
964        let result = table.lookup_one(rows, cols);
965        let result_arr = result.to_array();
966        let rows_arr = rows.to_array();
967        let cols_arr = cols.to_array();
968
969        // Verify against scalar get()
970        for i in 0..16 {
971            let expected = table.get(rows_arr[i], cols_arr[i]);
972            assert_eq!(
973                result_arr[i], expected,
974                "Mismatch at lane {}: SIMD={}, scalar={}",
975                i, result_arr[i], expected
976            );
977        }
978    }
979
980    #[test]
981    fn test_table2d_large_table() {
982        // 256 × 256 = 64K entries (maximum size)
983        let mut data = vec![0u8; 65536];
984        for r in 0..256 {
985            for c in 0..256 {
986                data[r * 256 + c] = (r ^ c) as u8; // XOR pattern
987            }
988        }
989        let table = Table2dU8xU8::from_flat(&data, 256);
990
991        assert_eq!(table.num_rows(), 256);
992        assert_eq!(table.num_cols(), 256);
993
994        // Test some lookups
995        let rows = u8x16::from([0, 255, 128, 64, 32, 16, 8, 4, 2, 1, 100, 200, 50, 150, 75, 175]);
996        let cols = u8x16::from([255, 0, 128, 64, 32, 16, 8, 4, 2, 1, 50, 100, 200, 75, 175, 150]);
997
998        let result = table.lookup_one(rows, cols);
999        let result_arr = result.to_array();
1000        let rows_arr = rows.to_array();
1001        let cols_arr = cols.to_array();
1002
1003        for i in 0..16 {
1004            let expected = rows_arr[i] ^ cols_arr[i];
1005            assert_eq!(result_arr[i], expected, "XOR mismatch at lane {}", i);
1006        }
1007    }
1008
1009    #[test]
1010    fn test_table2d_non_power_of_two_cols() {
1011        // Test with 17 columns (not power of 2)
1012        let data = create_table2d_test_data(10, 17);
1013        let table = Table2dU8xU8::from_flat(&data, 17);
1014
1015        assert_eq!(table.num_rows(), 10);
1016        assert_eq!(table.num_cols(), 17);
1017
1018        let rows = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5]);
1019        let cols = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 15, 14, 13, 12, 11]);
1020
1021        let result = table.lookup_one(rows, cols);
1022        let result_arr = result.to_array();
1023        let rows_arr = rows.to_array();
1024        let cols_arr = cols.to_array();
1025
1026        for i in 0..16 {
1027            let expected = table.get(rows_arr[i], cols_arr[i]);
1028            assert_eq!(result_arr[i], expected, "Mismatch at lane {}", i);
1029        }
1030    }
1031}
1032