1use crate::simd_gather::gather_u32index_u8;
36use crate::wide_utils::WideUtilsExt;
37use std::fmt;
38use wide::{u8x16, u16x16, u32x16};
39
40#[cfg(target_arch = "aarch64")]
41use core::arch::aarch64::{uint8x16x4_t, vld1q_u8, vqtbl4q_u8, vst1q_u8};
42
43#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
44use core::arch::x86_64::{
45 __m128i, __m512i, _mm_loadu_si128, _mm_storeu_si128, _mm512_castsi128_si512,
46 _mm512_castsi512_si128, _mm512_loadu_si512, _mm512_permutexvar_epi8, _mm512_storeu_si512,
47};
48
49#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
50use std::arch::is_x86_feature_detected as det;
51
/// A 64-entry byte lookup table stored in an architecture-specific form.
///
/// Which fields exist is decided at compile time by `target_arch`:
/// aarch64 keeps the table in NEON registers, x86/x86_64 keeps a plain byte
/// copy plus an optional AVX-512 register image.
pub struct Table64 {
    /// aarch64 only: the 64 table bytes as four 16-byte NEON vectors, the
    /// operand form consumed by `vqtbl4q_u8` in the lookup methods.
    #[cfg(target_arch = "aarch64")]
    neon_tbl: uint8x16x4_t,

    /// x86/x86_64 only: byte copy of the table, read by the scalar fallback
    /// and returned by `as_bytes`.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    bytes: [u8; 64],

    /// x86/x86_64 only: the table preloaded into a 512-bit register.
    /// `Some` only when `new` detected the required CPU features at runtime;
    /// `None` selects the scalar fallback.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    zmm: Option<__m512i>,
}
85
86impl Table64 {
87 #[inline]
88 pub fn new(table: &[u8; 64]) -> Self {
89 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
90 {
91 let zmm = if is_x86_avx512_vbmi() {
92 unsafe {
93 let z = _mm512_loadu_si512(table.as_ptr() as *const _);
94 Some(z)
95 }
96 } else {
97 None
98 };
99
100 Self { bytes: *table, zmm }
101 }
102
103 #[cfg(target_arch = "aarch64")]
104 {
105 Self {
106 neon_tbl: unsafe {
107 let t0 = vld1q_u8(table.as_ptr());
108 let t1 = vld1q_u8(table.as_ptr().add(16));
109 let t2 = vld1q_u8(table.as_ptr().add(32));
110 let t3 = vld1q_u8(table.as_ptr().add(48));
111 uint8x16x4_t(t0, t1, t2, t3)
112 },
113 }
114 }
115 }
116
117 #[inline]
120 pub fn lookup_one(&self, idx: u8x16) -> u8x16 {
121 #[cfg(target_arch = "aarch64")]
122 unsafe {
123 let i = vld1q_u8(idx.as_array().as_ptr());
124 let r = vqtbl4q_u8(self.neon_tbl, i);
125 let mut out = [0u8; 16];
126 vst1q_u8(out.as_mut_ptr(), r);
127 u8x16::from(out)
128 }
129
130 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
131 {
132 if let Some(tzmm) = self.zmm {
133 unsafe {
134 let iv_128 = _mm_loadu_si128(idx.as_array().as_ptr() as *const __m128i);
136 let iv = _mm512_castsi128_si512(iv_128);
138 let rv = _mm512_permutexvar_epi8(iv, tzmm);
140 let rv_128 = _mm512_castsi512_si128(rv);
142 let mut result = [0u8; 16];
144 _mm_storeu_si128(result.as_mut_ptr() as *mut __m128i, rv_128);
145 u8x16::from(result)
146 }
147 } else {
148 scalar_lookup_1x16(&self.bytes, idx)
149 }
150 }
151
152 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")))]
153 compile_error!(
154 "Table64::lookup_one is implemented for aarch64 (NEON) and x86/x86_64 (AVX-512VBMI)."
155 );
156 }
157
158 #[inline]
178 pub fn lookup_one_2d(&self, rows: u8x16, cols: u8x16) -> u8x16 {
179 debug_assert!(
180 rows.to_array().iter().all(|&r| r < 8),
181 "All row indices must be < 8"
182 );
183 debug_assert!(
184 cols.to_array().iter().all(|&c| c < 8),
185 "All column indices must be < 8"
186 );
187
188 let idx = rows.double().double().double() + cols;
192 self.lookup_one(idx)
193 }
194
195 #[inline]
198 fn as_bytes(&self) -> [u8; 64] {
199 #[cfg(target_arch = "aarch64")]
200 {
201 unsafe {
202 let mut bytes = [0u8; 64];
203 vst1q_u8(bytes.as_mut_ptr(), self.neon_tbl.0);
204 vst1q_u8(bytes.as_mut_ptr().add(16), self.neon_tbl.1);
205 vst1q_u8(bytes.as_mut_ptr().add(32), self.neon_tbl.2);
206 vst1q_u8(bytes.as_mut_ptr().add(48), self.neon_tbl.3);
207 bytes
208 }
209 }
210
211 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
212 {
213 self.bytes
214 }
215 }
216
217 #[inline]
221 pub fn lookup(&self, idx: &[u8x16], out: &mut [u8x16]) {
222 assert_eq!(idx.len(), out.len());
223
224 #[cfg(target_arch = "aarch64")]
225 unsafe {
226 let idx_bytes = idx.as_ptr() as *const u8;
228 let out_bytes = out.as_mut_ptr() as *mut u8;
229
230 for b in 0..idx.len() {
231 let i_ptr = idx_bytes.add(b * 16);
232 let o_ptr = out_bytes.add(b * 16);
233
234 let i = vld1q_u8(i_ptr);
235 let r = vqtbl4q_u8(self.neon_tbl, i); vst1q_u8(o_ptr, r);
237 }
238 }
239
240 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
241 unsafe {
242 let mut i = 0usize;
243 if let Some(tzmm) = self.zmm {
244 let idx_bytes = idx.as_ptr() as *const u8;
246 let out_bytes = out.as_mut_ptr() as *mut u8;
247
248 while i + 4 <= idx.len() {
249 let off = i * 16;
250 let iv = _mm512_loadu_si512(idx_bytes.add(off) as *const __m512i);
251 let rv = _mm512_permutexvar_epi8(iv, tzmm);
252 _mm512_storeu_si512(out_bytes.add(off) as *mut __m512i, rv);
253 i += 4;
254 }
255 }
256
257 for k in i..idx.len() {
259 out[k] = scalar_lookup_1x16(&self.bytes, idx[k]);
260 }
261 }
262
263 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")))]
264 compile_error!(
265 "Table64::lookup is implemented for aarch64 (NEON) and x86/x86_64 (AVX-512VBMI)."
266 );
267 }
268}
269
270impl Clone for Table64 {
271 fn clone(&self) -> Self {
272 let bytes = self.as_bytes();
273 Self::new(&bytes)
274 }
275}
276
277impl Default for Table64 {
278 fn default() -> Self {
279 Self::new(&[0u8; 64])
280 }
281}
282
283impl fmt::Debug for Table64 {
284 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285 let bytes = self.as_bytes();
286 writeln!(f, "Table64 {{")?;
287 writeln!(f, " col 0 col 1 col 2 col 3 col 4 col 5 col 6 col 7")?;
288 for row in 0..8 {
289 write!(f, "row {}: ", row)?;
290 for col in 0..8 {
291 let idx = row * 8 + col;
292 write!(f, "{:5} ", bytes[idx])?;
293 }
294 writeln!(f)?;
295 }
296 write!(f, "}}")
297 }
298}
299
/// A row-major 2D table of bytes addressed by `(row, col)` pairs of `u8`.
///
/// The constructors in this file cap the table at 256 columns and 65536
/// entries in total. The derived `Default` yields an empty table with
/// `num_cols == 0`.
#[derive(Clone, Default)]
pub struct Table2dU8xU8 {
    /// Flattened row-major contents; the entry at `(row, col)` lives at
    /// `row * num_cols + col`.
    data: Vec<u8>,
    /// Entries per row. Constructors enforce `1..=256`; only `Default`
    /// produces 0.
    num_cols: u16,
}
340
341impl Table2dU8xU8 {
342 #[inline]
355 pub fn from_flat(data: &[u8], num_cols: usize) -> Self {
356 assert!(num_cols > 0 && num_cols <= 256, "num_cols must be 1..=256");
357 assert!(data.len() % num_cols == 0, "data length must be multiple of num_cols");
358 assert!(data.len() <= 65536, "data length must be <= 65536 (64K entries)");
359
360 Self {
361 data: data.to_vec(),
362 num_cols: num_cols as u16,
363 }
364 }
365
366 #[inline]
375 pub fn from_2d(matrix: &[&[u8]]) -> Self {
376 assert!(!matrix.is_empty(), "matrix cannot be empty");
377 let num_cols = matrix[0].len();
378 assert!(num_cols > 0 && num_cols <= 256, "num_cols must be 1..=256");
379 assert!(matrix.iter().all(|row| row.len() == num_cols), "all rows must have same length");
380 assert!(matrix.len() * num_cols <= 65536, "total size must be <= 65536");
381
382 let mut data = Vec::with_capacity(matrix.len() * num_cols);
383 for row in matrix {
384 data.extend_from_slice(row);
385 }
386
387 Self {
388 data,
389 num_cols: num_cols as u16,
390 }
391 }
392
393 #[inline]
395 pub fn num_cols(&self) -> usize {
396 self.num_cols as usize
397 }
398
399 #[inline]
401 pub fn num_rows(&self) -> usize {
402 self.data.len() / self.num_cols as usize
403 }
404
405 #[inline]
407 pub fn len(&self) -> usize {
408 self.data.len()
409 }
410
411 #[inline]
413 pub fn is_empty(&self) -> bool {
414 self.data.is_empty()
415 }
416
417 #[inline]
429 pub fn lookup_one(&self, rows: u8x16, cols: u8x16) -> u8x16 {
430 let rows_u16: u16x16 = u16x16::from(rows);
432 let cols_u16: u16x16 = u16x16::from(cols);
433 let num_cols_u16 = u16x16::splat(self.num_cols);
434
435 let indices_u16 = rows_u16 * num_cols_u16 + cols_u16;
437
438 let indices_u32: u32x16 = u32x16::from(indices_u16);
440
441 #[cfg(debug_assertions)]
443 {
444 let idx_arr = indices_u32.to_array();
445 for (i, &idx) in idx_arr.iter().enumerate() {
446 debug_assert!(
447 (idx as usize) < self.data.len(),
448 "Index out of bounds at lane {}: {} >= {}",
449 i, idx, self.data.len()
450 );
451 }
452 }
453
454 gather_u32index_u8(indices_u32, &self.data, 1)
456 }
457
458 #[inline]
463 pub fn get(&self, row: u8, col: u8) -> u8 {
464 let index = (row as usize) * (self.num_cols as usize) + (col as usize);
465 self.data[index]
466 }
467}
468
469impl fmt::Debug for Table2dU8xU8 {
470 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471 let num_rows = self.num_rows();
472 let num_cols = self.num_cols as usize;
473
474 writeln!(f, "Table2dU8xU8 {{")?;
475 writeln!(f, " dimensions: {} rows × {} cols", num_rows, num_cols)?;
476
477 if self.data.is_empty() {
478 return write!(f, " (empty)}}");
479 }
480
481 const MAX_DISPLAY_ROWS: usize = 20;
483 const MAX_DISPLAY_COLS: usize = 20;
484
485 let display_rows = num_rows.min(MAX_DISPLAY_ROWS);
486 let display_cols = num_cols.min(MAX_DISPLAY_COLS);
487 let show_row_ellipsis = num_rows > MAX_DISPLAY_ROWS;
488 let show_col_ellipsis = num_cols > MAX_DISPLAY_COLS;
489
490 write!(f, " ")?;
492 for col in 0..display_cols {
493 write!(f, " col{:3}", col)?;
494 }
495 if show_col_ellipsis {
496 write!(f, " ...")?;
497 }
498 writeln!(f)?;
499
500 for row in 0..display_rows {
502 write!(f, " row{:3}:", row)?;
503 for col in 0..display_cols {
504 let idx = row * num_cols + col;
505 write!(f, "{:5}", self.data[idx])?;
506 }
507 if show_col_ellipsis {
508 write!(f, " ...")?;
509 }
510 writeln!(f)?;
511 }
512
513 if show_row_ellipsis {
514 writeln!(f, " ...")?;
515 }
516
517 write!(f, "}}")
518 }
519}
520
/// Runtime check for the CPU features the vector path is gated on: reports
/// `true` only when both `avx512bw` and `avx512vbmi` are detected.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline]
fn is_x86_avx512_vbmi() -> bool {
    if !det!("avx512bw") {
        return false;
    }
    det!("avx512vbmi")
}
530
531#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
534#[inline]
535fn scalar_lookup_1x16(table: &[u8; 64], idx: u8x16) -> u8x16 {
536 let i = idx.to_array();
537 debug_assert!(i.iter().all(|&x| x < 64));
538 let out = [
539 table[i[0] as usize],
540 table[i[1] as usize],
541 table[i[2] as usize],
542 table[i[3] as usize],
543 table[i[4] as usize],
544 table[i[5] as usize],
545 table[i[6] as usize],
546 table[i[7] as usize],
547 table[i[8] as usize],
548 table[i[9] as usize],
549 table[i[10] as usize],
550 table[i[11] as usize],
551 table[i[12] as usize],
552 table[i[13] as usize],
553 table[i[14] as usize],
554 table[i[15] as usize],
555 ];
556 u8x16::from(out)
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// 64-entry fixture where entry `i` holds `i * 3 + 7`.
    fn create_test_table() -> [u8; 64] {
        let mut table = [0u8; 64];
        for (i, slot) in table.iter_mut().enumerate() {
            *slot = (i * 3 + 7) as u8;
        }
        table
    }

    /// Construction succeeds and the Debug output can be rendered.
    #[test]
    fn test_table64_new() {
        let table = Table64::new(&create_test_table());
        println!("\n{:?}", table);
    }

    /// Sequential indices 0..16 return the first 16 table entries.
    #[test]
    fn test_lookup_one_basic() {
        let table_data = create_test_table();
        let table = Table64::new(&table_data);

        let idx = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let result_arr = table.lookup_one(idx).to_array();

        for (i, (&got, &want)) in result_arr.iter().zip(table_data.iter()).enumerate() {
            assert_eq!(
                got, want,
                "Mismatch at index {}: expected {}, got {}",
                i, want, got
            );
        }
    }

    /// Indices spread over the whole 0..64 range resolve lane by lane.
    #[test]
    fn test_lookup_one_scattered_indices() {
        let table_data = create_test_table();
        let table = Table64::new(&table_data);

        let idx = u8x16::from([0, 63, 32, 16, 48, 1, 62, 31, 15, 47, 8, 56, 4, 60, 20, 40]);
        let result_arr = table.lookup_one(idx).to_array();
        let idx_arr = idx.to_array();

        for (i, (&got, &lane_idx)) in result_arr.iter().zip(idx_arr.iter()).enumerate() {
            let want = table_data[lane_idx as usize];
            assert_eq!(
                got, want,
                "Mismatch at position {}: idx={}, expected {}, got {}",
                i, lane_idx, want, got
            );
        }
    }

    /// A splatted index returns the same entry in every lane.
    #[test]
    fn test_lookup_one_all_same_index() {
        let table_data = create_test_table();
        let table = Table64::new(&table_data);

        let result_arr = table.lookup_one(u8x16::splat(42)).to_array();
        let expected = table_data[42];

        for &got in result_arr.iter() {
            assert_eq!(got, expected, "All lookups should return the same value");
        }
    }

    /// A batch of four vectors covering 0..64 reproduces the whole table.
    #[test]
    fn test_lookup_batch() {
        let table_data = create_test_table();
        let table = Table64::new(&table_data);

        let indices = [
            u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
            u8x16::from([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]),
            u8x16::from([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]),
            u8x16::from([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]),
        ];
        let mut output = [u8x16::splat(0); 4];

        table.lookup(&indices, &mut output);

        for (vec_idx, out_vec) in output.iter().enumerate() {
            let lanes = out_vec.to_array();
            for (lane, &got) in lanes.iter().enumerate() {
                let want = table_data[vec_idx * 16 + lane];
                assert_eq!(
                    got, want,
                    "Mismatch at vec {}, lane {}: expected {}, got {}",
                    vec_idx, lane, want, got
                );
            }
        }
    }

    /// The single-vector and batch entry points agree on the same input.
    #[test]
    fn test_lookup_one_matches_lookup_batch() {
        let table = Table64::new(&create_test_table());

        let idx = u8x16::from([5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 0, 32, 63, 1]);

        let single_result = table.lookup_one(idx);

        let mut batch_output = [u8x16::splat(0)];
        table.lookup(&[idx], &mut batch_output);

        assert_eq!(
            single_result.to_array(),
            batch_output[0].to_array(),
            "lookup_one and lookup should produce the same result"
        );
    }

    /// With `table[i] == i`, a lookup returns its own indices.
    #[test]
    fn test_identity_table() {
        let mut table_data = [0u8; 64];
        for (i, slot) in table_data.iter_mut().enumerate() {
            *slot = i as u8;
        }
        let table = Table64::new(&table_data);

        let idx = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let result = table.lookup_one(idx);

        assert_eq!(idx.to_array(), result.to_array(), "Identity table should return input indices");
    }

    /// 8x8 fixture where the entry at `(row, col)` holds `row * 10 + col`.
    fn create_2d_test_table() -> [u8; 64] {
        let mut table = [0u8; 64];
        for (flat, slot) in table.iter_mut().enumerate() {
            let (row, col) = (flat / 8, flat % 8);
            *slot = (row * 10 + col) as u8;
        }
        table
    }

    /// Rows 0 and 1 read in full, eight lanes each.
    #[test]
    fn test_lookup_one_2d_basic() {
        let table = Table64::new(&create_2d_test_table());

        let rows = u8x16::from([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]);
        let cols = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]);

        let result_arr = table.lookup_one_2d(rows, cols).to_array();
        let (row0, row1) = result_arr.split_at(8);

        for (col, &got) in row0.iter().enumerate() {
            assert_eq!(got, col as u8, "Row 0, col {}", col);
        }
        for (col, &got) in row1.iter().enumerate() {
            assert_eq!(got, (10 + col) as u8, "Row 1, col {}", col);
        }
    }

    /// Main diagonal in the low eight lanes, anti-diagonal in the high eight.
    #[test]
    fn test_lookup_one_2d_diagonal() {
        let table = Table64::new(&create_2d_test_table());

        let rows = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0]);
        let cols = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]);

        let result_arr = table.lookup_one_2d(rows, cols).to_array();
        let (main_diag, anti_diag) = result_arr.split_at(8);

        for (i, &got) in main_diag.iter().enumerate() {
            let expected = (i * 10 + i) as u8;
            assert_eq!(got, expected, "Main diagonal position {}", i);
        }

        let expected_anti = [70, 61, 52, 43, 34, 25, 16, 7u8];
        for (i, (&got, &want)) in anti_diag.iter().zip(expected_anti.iter()).enumerate() {
            assert_eq!(got, want, "Anti-diagonal position {}", i);
        }
    }

    /// The four corners of the grid, repeated across the lanes.
    #[test]
    fn test_lookup_one_2d_corners() {
        let table = Table64::new(&create_2d_test_table());

        let rows = u8x16::from([0, 0, 7, 7, 0, 0, 7, 7, 0, 0, 7, 7, 0, 0, 7, 7]);
        let cols = u8x16::from([0, 7, 0, 7, 0, 7, 0, 7, 0, 7, 0, 7, 0, 7, 0, 7]);

        let result_arr = table.lookup_one_2d(rows, cols).to_array();

        let expected = [0u8, 7, 70, 77, 0, 7, 70, 77, 0, 7, 70, 77, 0, 7, 70, 77];
        assert_eq!(result_arr, expected, "Corner lookups");
    }

    /// Fixed row 5 with varying columns.
    #[test]
    fn test_lookup_one_2d_same_row() {
        let table = Table64::new(&create_2d_test_table());

        let rows = u8x16::splat(5);
        let cols = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0]);

        let result_arr = table.lookup_one_2d(rows, cols).to_array();
        let cols_arr = cols.to_array();

        for (&got, &col) in result_arr.iter().zip(cols_arr.iter()) {
            let expected: u8 = 50 + col;
            assert_eq!(got, expected, "Row 5, col {}", col);
        }
    }

    /// Fixed column 3 with the rows 0..8 repeated twice.
    #[test]
    fn test_lookup_one_2d_same_col() {
        let table = Table64::new(&create_2d_test_table());

        let rows = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]);
        let cols = u8x16::splat(3);

        let result_arr = table.lookup_one_2d(rows, cols).to_array();
        let (first_half, second_half) = result_arr.split_at(8);

        for (i, (&lo, &hi)) in first_half.iter().zip(second_half.iter()).enumerate() {
            let expected = (i * 10 + 3) as u8;
            assert_eq!(lo, expected, "Row {}, col 3", i);
            assert_eq!(hi, expected, "Row {}, col 3 (second half)", i);
        }
    }

    /// The 2D lookup equals a 1D lookup at `row * 8 + col`.
    #[test]
    fn test_lookup_one_2d_matches_lookup_one() {
        let table = Table64::new(&create_2d_test_table());

        let rows = u8x16::from([0, 3, 7, 2, 5, 1, 6, 4, 7, 0, 3, 5, 2, 6, 1, 4]);
        let cols = u8x16::from([5, 2, 0, 7, 3, 6, 1, 4, 7, 0, 4, 2, 6, 3, 5, 1]);

        let rows_arr = rows.to_array();
        let cols_arr = cols.to_array();
        let mut expected_idx = [0u8; 16];
        for (slot, (&r, &c)) in expected_idx.iter_mut().zip(rows_arr.iter().zip(cols_arr.iter())) {
            *slot = r * 8 + c;
        }

        let result_2d = table.lookup_one_2d(rows, cols);
        let result_1d = table.lookup_one(u8x16::from(expected_idx));

        assert_eq!(
            result_2d.to_array(),
            result_1d.to_array(),
            "lookup_one_2d should match lookup_one with computed indices"
        );
    }

    /// Row-major fixture where the entry at `(r, c)` holds `(r * 10 + c) % 256`.
    fn create_table2d_test_data(num_rows: usize, num_cols: usize) -> Vec<u8> {
        (0..num_rows)
            .flat_map(|r| (0..num_cols).map(move |c| ((r * 10 + c) % 256) as u8))
            .collect()
    }

    /// `from_flat` reports the expected dimensions for a 16x16 table.
    #[test]
    fn test_table2d_from_flat_basic() {
        let table = Table2dU8xU8::from_flat(&create_table2d_test_data(16, 16), 16);

        println!("\n{:?}", table);
        assert_eq!(table.num_rows(), 16);
        assert_eq!(table.num_cols(), 16);
        assert_eq!(table.len(), 256);
    }

    /// `from_2d` flattens rows in order and preserves every entry.
    #[test]
    fn test_table2d_from_2d() {
        let matrix: [&[u8]; 3] = [&[0, 1, 2, 3], &[10, 11, 12, 13], &[20, 21, 22, 23]];

        let table = Table2dU8xU8::from_2d(&matrix);

        assert_eq!(table.num_rows(), 3);
        assert_eq!(table.num_cols(), 4);
        assert_eq!(table.len(), 12);

        assert_eq!(table.get(0, 0), 0);
        assert_eq!(table.get(0, 3), 3);
        assert_eq!(table.get(1, 0), 10);
        assert_eq!(table.get(2, 3), 23);
    }

    /// Row 0 with columns 0..16.
    #[test]
    fn test_table2d_lookup_one_basic() {
        let table = Table2dU8xU8::from_flat(&create_table2d_test_data(16, 16), 16);

        let rows = u8x16::splat(0);
        let cols = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);

        let result_arr = table.lookup_one(rows, cols).to_array();

        for (i, &got) in result_arr.iter().enumerate() {
            assert_eq!(got, i as u8, "Row 0, col {}", i);
        }
    }

    /// Column 5 with rows 0..16.
    #[test]
    fn test_table2d_lookup_one_different_rows() {
        let table = Table2dU8xU8::from_flat(&create_table2d_test_data(16, 16), 16);

        let rows = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let cols = u8x16::splat(5);

        let result_arr = table.lookup_one(rows, cols).to_array();

        for (i, &got) in result_arr.iter().enumerate() {
            let expected = ((i * 10 + 5) % 256) as u8;
            assert_eq!(got, expected, "Row {}, col 5", i);
        }
    }

    /// Arbitrary (row, col) pairs across a 16x16 table.
    #[test]
    fn test_table2d_lookup_one_scattered() {
        let table = Table2dU8xU8::from_flat(&create_table2d_test_data(16, 16), 16);

        let rows = u8x16::from([0, 5, 10, 15, 3, 8, 12, 1, 7, 14, 2, 9, 4, 11, 6, 13]);
        let cols = u8x16::from([0, 15, 5, 10, 3, 8, 12, 1, 7, 14, 2, 9, 4, 11, 6, 13]);

        let result_arr = table.lookup_one(rows, cols).to_array();
        let rows_arr = rows.to_array();
        let cols_arr = cols.to_array();

        let coords = rows_arr.iter().zip(cols_arr.iter());
        for (i, (&got, (&row, &col))) in result_arr.iter().zip(coords).enumerate() {
            let expected = ((row as usize * 10 + col as usize) % 256) as u8;
            assert_eq!(
                got, expected,
                "Mismatch at lane {}: row={}, col={}, expected={}, got={}",
                i, row, col, expected, got
            );
        }
    }

    /// The vector lookup agrees with the scalar `get` on a 32x20 table.
    #[test]
    fn test_table2d_lookup_matches_scalar() {
        let table = Table2dU8xU8::from_flat(&create_table2d_test_data(32, 20), 20);

        let rows = u8x16::from([0, 5, 10, 15, 20, 25, 30, 31, 1, 6, 11, 16, 21, 26, 28, 29]);
        let cols = u8x16::from([0, 5, 10, 15, 19, 0, 5, 10, 1, 6, 11, 16, 18, 1, 6, 11]);

        let result_arr = table.lookup_one(rows, cols).to_array();
        let rows_arr = rows.to_array();
        let cols_arr = cols.to_array();

        let coords = rows_arr.iter().zip(cols_arr.iter());
        for (i, (&got, (&row, &col))) in result_arr.iter().zip(coords).enumerate() {
            let expected = table.get(row, col);
            assert_eq!(
                got, expected,
                "Mismatch at lane {}: SIMD={}, scalar={}",
                i, got, expected
            );
        }
    }

    /// Maximum-size 256x256 table holding `row ^ col`.
    #[test]
    fn test_table2d_large_table() {
        let data: Vec<u8> = (0..65536usize)
            .map(|flat| ((flat / 256) ^ (flat % 256)) as u8)
            .collect();
        let table = Table2dU8xU8::from_flat(&data, 256);

        assert_eq!(table.num_rows(), 256);
        assert_eq!(table.num_cols(), 256);

        let rows = u8x16::from([0, 255, 128, 64, 32, 16, 8, 4, 2, 1, 100, 200, 50, 150, 75, 175]);
        let cols = u8x16::from([255, 0, 128, 64, 32, 16, 8, 4, 2, 1, 50, 100, 200, 75, 175, 150]);

        let result_arr = table.lookup_one(rows, cols).to_array();
        let rows_arr = rows.to_array();
        let cols_arr = cols.to_array();

        let coords = rows_arr.iter().zip(cols_arr.iter());
        for (i, (&got, (&row, &col))) in result_arr.iter().zip(coords).enumerate() {
            let expected = row ^ col;
            assert_eq!(got, expected, "XOR mismatch at lane {}", i);
        }
    }

    /// A column count that is not a power of two (17) still indexes correctly.
    #[test]
    fn test_table2d_non_power_of_two_cols() {
        let table = Table2dU8xU8::from_flat(&create_table2d_test_data(10, 17), 17);

        assert_eq!(table.num_rows(), 10);
        assert_eq!(table.num_cols(), 17);

        let rows = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5]);
        let cols = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 15, 14, 13, 12, 11]);

        let result_arr = table.lookup_one(rows, cols).to_array();
        let rows_arr = rows.to_array();
        let cols_arr = cols.to_array();

        let coords = rows_arr.iter().zip(cols_arr.iter());
        for (i, (&got, (&row, &col))) in result_arr.iter().zip(coords).enumerate() {
            let expected = table.get(row, col);
            assert_eq!(got, expected, "Mismatch at lane {}", i);
        }
    }
}
1032