1#![allow(unsafe_op_in_unsafe_fn)]
32use std::sync::OnceLock;
61
/// Instruction-set tier used to pick a Hadamard transform implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdCapability {
    /// No SIMD: one `f32` per operation.
    Scalar,
    /// x86_64 SSE4.1 (four `f32` lanes).
    Sse41,
    /// x86_64 AVX2 (eight `f32` lanes).
    Avx2,
    /// x86_64 AVX-512F (sixteen `f32` lanes).
    Avx512,
    /// AArch64 NEON (four `f32` lanes).
    Neon,
}

impl SimdCapability {
    /// Reports the widest tier available on the machine running this code.
    ///
    /// On x86_64 the CPU is probed at runtime, preferring AVX-512F, then AVX2,
    /// then SSE4.1, and finally [`SimdCapability::Scalar`]. On aarch64 the
    /// result is always [`SimdCapability::Neon`]. Every other architecture
    /// yields [`SimdCapability::Scalar`].
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512f") {
                Self::Avx512
            } else if is_x86_feature_detected!("avx2") {
                Self::Avx2
            } else if is_x86_feature_detected!("sse4.1") {
                Self::Sse41
            } else {
                Self::Scalar
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            Self::Neon
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self::Scalar
        }
    }

    /// Number of `f32` values processed per vector operation for this tier.
    pub fn width(&self) -> usize {
        match *self {
            Self::Avx512 => 16,
            Self::Avx2 => 8,
            Self::Neon | Self::Sse41 => 4,
            Self::Scalar => 1,
        }
    }
}
112
113static CAPABILITY: OnceLock<SimdCapability> = OnceLock::new();
115
116pub fn simd_capability() -> SimdCapability {
118 *CAPABILITY.get_or_init(SimdCapability::detect)
119}
120
/// Hadamard transform dispatcher bound to a chosen [`SimdCapability`].
///
/// The value is `Copy` and stores nothing but the capability selector.
#[derive(Debug, Clone, Copy)]
pub struct HadamardKernel {
    /// Instruction-set tier this kernel was created with.
    capability: SimdCapability,
}
126
127impl HadamardKernel {
128 pub fn detect() -> Self {
130 Self {
131 capability: simd_capability(),
132 }
133 }
134
135 pub fn with_capability(capability: SimdCapability) -> Self {
137 Self { capability }
138 }
139
140 #[inline]
142 pub fn transform(&self, data: &mut [f32]) {
143 let n = data.len();
144
145 if n == 0 || !n.is_power_of_two() {
146 return;
147 }
148
149 match self.capability {
150 #[cfg(target_arch = "x86_64")]
151 SimdCapability::Avx512 => unsafe { hadamard_avx512(data) },
152 #[cfg(target_arch = "x86_64")]
153 SimdCapability::Avx2 => unsafe { hadamard_avx2(data) },
154 #[cfg(target_arch = "x86_64")]
155 SimdCapability::Sse41 => unsafe { hadamard_sse41(data) },
156 #[cfg(target_arch = "aarch64")]
157 SimdCapability::Neon => unsafe { hadamard_neon(data) },
158 _ => hadamard_scalar(data),
159 }
160 }
161
162 pub fn transform_batch(&self, flat_data: &mut [f32], dim: usize) {
164 if dim == 0 || !dim.is_power_of_two() {
165 return;
166 }
167
168 let num_vectors = flat_data.len() / dim;
169
170 for i in 0..num_vectors {
171 let start = i * dim;
172 let slice = &mut flat_data[start..start + dim];
173 self.transform(slice);
174 }
175 }
176
177 pub fn capability(&self) -> SimdCapability {
179 self.capability
180 }
181}
182
183impl Default for HadamardKernel {
184 fn default() -> Self {
185 Self::detect()
186 }
187}
188
/// Portable in-place fast Walsh–Hadamard transform, normalized by `1 / sqrt(n)`.
///
/// `data` is returned unchanged when its length is zero or not a power of two.
/// Because of the normalization, applying the function twice restores the
/// input up to floating-point rounding.
pub fn hadamard_scalar(data: &mut [f32]) {
    let len = data.len();
    // `is_power_of_two` is false for zero, so this also rejects empty input.
    if !len.is_power_of_two() {
        return;
    }

    // Butterfly stages: pair element `j` with `j + half` inside each block of
    // `2 * half` elements, doubling `half` every stage.
    let mut half = 1;
    while half < len {
        for block in data.chunks_exact_mut(half * 2) {
            let (lower, upper) = block.split_at_mut(half);
            for (a, b) in lower.iter_mut().zip(upper.iter_mut()) {
                let (x, y) = (*a, *b);
                *a = x + y;
                *b = x - y;
            }
        }
        half *= 2;
    }

    let norm = 1.0 / (len as f32).sqrt();
    data.iter_mut().for_each(|value| *value *= norm);
}
220
221#[cfg(target_arch = "x86_64")]
226#[target_feature(enable = "avx2")]
227unsafe fn hadamard_avx2(data: &mut [f32]) {
228 use std::arch::x86_64::*;
229 unsafe {
230 let n = data.len();
231
232 if n < 8 {
234 hadamard_scalar(data);
235 return;
236 }
237
238 let mut h = 1;
240
241 while h < 8 && h < n {
243 for i in (0..n).step_by(h * 2) {
244 for j in i..(i + h) {
245 let x = *data.get_unchecked(j);
246 let y = *data.get_unchecked(j + h);
247 *data.get_unchecked_mut(j) = x + y;
248 *data.get_unchecked_mut(j + h) = x - y;
249 }
250 }
251 h *= 2;
252 }
253
254 while h < n {
256 let blocks = n / (h * 2);
257
258 for block in 0..blocks {
259 let base = block * h * 2;
260
261 for j in (0..h).step_by(8) {
263 let idx_a = base + j;
264 let idx_b = base + h + j;
265
266 let va = _mm256_loadu_ps(data.as_ptr().add(idx_a));
267 let vb = _mm256_loadu_ps(data.as_ptr().add(idx_b));
268
269 let sum = _mm256_add_ps(va, vb);
270 let diff = _mm256_sub_ps(va, vb);
271
272 _mm256_storeu_ps(data.as_mut_ptr().add(idx_a), sum);
273 _mm256_storeu_ps(data.as_mut_ptr().add(idx_b), diff);
274 }
275
276 let remainder = h % 8;
278 if remainder > 0 {
279 let start = h - remainder;
280 for j in start..h {
281 let idx_a = base + j;
282 let idx_b = base + h + j;
283 let x = *data.get_unchecked(idx_a);
284 let y = *data.get_unchecked(idx_b);
285 *data.get_unchecked_mut(idx_a) = x + y;
286 *data.get_unchecked_mut(idx_b) = x - y;
287 }
288 }
289 }
290
291 h *= 2;
292 }
293
294 let scale = 1.0 / (n as f32).sqrt();
296 let vscale = _mm256_set1_ps(scale);
297
298 let chunks = n / 8;
299 for i in 0..chunks {
300 let offset = i * 8;
301 let v = _mm256_loadu_ps(data.as_ptr().add(offset));
302 let scaled = _mm256_mul_ps(v, vscale);
303 _mm256_storeu_ps(data.as_mut_ptr().add(offset), scaled);
304 }
305
306 for i in (chunks * 8)..n {
308 *data.get_unchecked_mut(i) *= scale;
309 }
310 }
311}
312
313#[cfg(target_arch = "x86_64")]
318#[target_feature(enable = "sse4.1")]
319unsafe fn hadamard_sse41(data: &mut [f32]) {
320 use std::arch::x86_64::*;
321 unsafe {
322 let n = data.len();
323
324 if n < 4 {
325 hadamard_scalar(data);
326 return;
327 }
328
329 let mut h = 1;
331
332 while h < 4 && h < n {
334 for i in (0..n).step_by(h * 2) {
335 for j in i..(i + h) {
336 let x = *data.get_unchecked(j);
337 let y = *data.get_unchecked(j + h);
338 *data.get_unchecked_mut(j) = x + y;
339 *data.get_unchecked_mut(j + h) = x - y;
340 }
341 }
342 h *= 2;
343 }
344
345 while h < n {
347 let blocks = n / (h * 2);
348
349 for block in 0..blocks {
350 let base = block * h * 2;
351
352 for j in (0..h).step_by(4) {
353 let idx_a = base + j;
354 let idx_b = base + h + j;
355
356 let va = _mm_loadu_ps(data.as_ptr().add(idx_a));
357 let vb = _mm_loadu_ps(data.as_ptr().add(idx_b));
358
359 let sum = _mm_add_ps(va, vb);
360 let diff = _mm_sub_ps(va, vb);
361
362 _mm_storeu_ps(data.as_mut_ptr().add(idx_a), sum);
363 _mm_storeu_ps(data.as_mut_ptr().add(idx_b), diff);
364 }
365
366 let remainder = h % 4;
368 if remainder > 0 {
369 let start = h - remainder;
370 for j in start..h {
371 let idx_a = base + j;
372 let idx_b = base + h + j;
373 let x = *data.get_unchecked(idx_a);
374 let y = *data.get_unchecked(idx_b);
375 *data.get_unchecked_mut(idx_a) = x + y;
376 *data.get_unchecked_mut(idx_b) = x - y;
377 }
378 }
379 }
380
381 h *= 2;
382 }
383
384 let scale = 1.0 / (n as f32).sqrt();
386 let vscale = _mm_set1_ps(scale);
387
388 let chunks = n / 4;
389 for i in 0..chunks {
390 let offset = i * 4;
391 let v = _mm_loadu_ps(data.as_ptr().add(offset));
392 let scaled = _mm_mul_ps(v, vscale);
393 _mm_storeu_ps(data.as_mut_ptr().add(offset), scaled);
394 }
395
396 for i in (chunks * 4)..n {
397 *data.get_unchecked_mut(i) *= scale;
398 }
399 }
400}
401
402#[cfg(target_arch = "x86_64")]
407#[target_feature(enable = "avx512f")]
408unsafe fn hadamard_avx512(data: &mut [f32]) {
409 use std::arch::x86_64::*;
410 unsafe {
411 let n = data.len();
412
413 if n < 16 {
414 hadamard_avx2(data);
415 return;
416 }
417
418 let mut h = 1;
420
421 while h < 16 && h < n {
423 for i in (0..n).step_by(h * 2) {
424 for j in i..(i + h) {
425 let x = *data.get_unchecked(j);
426 let y = *data.get_unchecked(j + h);
427 *data.get_unchecked_mut(j) = x + y;
428 *data.get_unchecked_mut(j + h) = x - y;
429 }
430 }
431 h *= 2;
432 }
433
434 while h < n {
436 let blocks = n / (h * 2);
437
438 for block in 0..blocks {
439 let base = block * h * 2;
440
441 for j in (0..h).step_by(16) {
442 let idx_a = base + j;
443 let idx_b = base + h + j;
444
445 let va = _mm512_loadu_ps(data.as_ptr().add(idx_a));
446 let vb = _mm512_loadu_ps(data.as_ptr().add(idx_b));
447
448 let sum = _mm512_add_ps(va, vb);
449 let diff = _mm512_sub_ps(va, vb);
450
451 _mm512_storeu_ps(data.as_mut_ptr().add(idx_a), sum);
452 _mm512_storeu_ps(data.as_mut_ptr().add(idx_b), diff);
453 }
454
455 let remainder = h % 16;
457 if remainder > 0 {
458 let start = h - remainder;
459 for j in start..h {
460 let idx_a = base + j;
461 let idx_b = base + h + j;
462 let x = *data.get_unchecked(idx_a);
463 let y = *data.get_unchecked(idx_b);
464 *data.get_unchecked_mut(idx_a) = x + y;
465 *data.get_unchecked_mut(idx_b) = x - y;
466 }
467 }
468 }
469
470 h *= 2;
471 }
472
473 let scale = 1.0 / (n as f32).sqrt();
475 let vscale = _mm512_set1_ps(scale);
476
477 let chunks = n / 16;
478 for i in 0..chunks {
479 let offset = i * 16;
480 let v = _mm512_loadu_ps(data.as_ptr().add(offset));
481 let scaled = _mm512_mul_ps(v, vscale);
482 _mm512_storeu_ps(data.as_mut_ptr().add(offset), scaled);
483 }
484
485 for i in (chunks * 16)..n {
486 *data.get_unchecked_mut(i) *= scale;
487 }
488 }
489}
490
/// NEON implementation of the in-place, `1 / sqrt(n)`-normalized Hadamard
/// transform.
///
/// Inputs shorter than four elements are delegated to [`hadamard_scalar`].
/// For longer inputs the first two butterfly stages (strides 1 and 2) are done
/// element by element; every later stage and the final scaling operate on four
/// `f32` lanes at a time.
///
/// # Safety
///
/// The CPU executing this function must support NEON.
///
/// `data.len()` is expected to be a power of two. Slices of four or more
/// elements with any other length stay memory-safe but do not produce a
/// meaningful transform (debug builds assert on this).
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn hadamard_neon(data: &mut [f32]) {
    use std::arch::aarch64::*;

    const LANES: usize = 4;

    let n = data.len();
    if n < LANES {
        hadamard_scalar(data);
        return;
    }
    debug_assert!(
        n.is_power_of_two(),
        "hadamard_neon requires a power-of-two length, got {}",
        n
    );

    // Stages whose stride is narrower than one vector.
    let mut h = 1;
    while h < LANES {
        for block in data.chunks_exact_mut(h * 2) {
            let (lower, upper) = block.split_at_mut(h);
            for (a, b) in lower.iter_mut().zip(upper.iter_mut()) {
                let (x, y) = (*a, *b);
                *a = x + y;
                *b = x - y;
            }
        }
        h *= 2;
    }

    // From here on `h` is a multiple of LANES, so each half-block splits into
    // whole vectors and no scalar remainder handling is needed.
    while h < n {
        for block in data.chunks_exact_mut(h * 2) {
            let (lower, upper) = block.split_at_mut(h);
            let pairs = lower
                .chunks_exact_mut(LANES)
                .zip(upper.chunks_exact_mut(LANES));
            for (a, b) in pairs {
                // SAFETY: `a` and `b` are disjoint slices of exactly LANES
                // `f32`s, so the 128-bit loads and stores stay in bounds.
                // NEON availability is guaranteed by the caller.
                unsafe {
                    let va = vld1q_f32(a.as_ptr());
                    let vb = vld1q_f32(b.as_ptr());
                    vst1q_f32(a.as_mut_ptr(), vaddq_f32(va, vb));
                    vst1q_f32(b.as_mut_ptr(), vsubq_f32(va, vb));
                }
            }
        }
        h *= 2;
    }

    let scale = 1.0 / (n as f32).sqrt();
    let vscale = vdupq_n_f32(scale);
    for lanes in data.chunks_exact_mut(LANES) {
        // SAFETY: `lanes` holds exactly LANES `f32`s, so the load and store
        // stay in bounds.
        unsafe {
            let v = vld1q_f32(lanes.as_ptr());
            vst1q_f32(lanes.as_mut_ptr(), vmulq_f32(v, vscale));
        }
    }
}
579
580#[inline]
586pub fn hadamard_transform(data: &mut [f32]) {
587 HadamardKernel::detect().transform(data);
588}
589
590pub fn hadamard_transform_batch(flat_data: &mut [f32], dim: usize) {
592 HadamardKernel::detect().transform_batch(flat_data, dim);
593}
594
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-trivial input of the requested length.
    fn sample(len: usize) -> Vec<f32> {
        (0..len).map(|i| (i as f32 * 0.37).sin()).collect()
    }

    /// Sum of squares of the elements.
    fn squared_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum()
    }

    /// Asserts element-wise agreement within a small relative tolerance.
    fn assert_close(expected: &[f32], actual: &[f32], context: &str) {
        assert_eq!(expected.len(), actual.len(), "{}: length mismatch", context);
        for (i, (a, b)) in expected.iter().zip(actual.iter()).enumerate() {
            let tolerance = 1e-5 * a.abs().max(1.0);
            assert!(
                (a - b).abs() <= tolerance,
                "{}: index {}: expected {} vs actual {}",
                context,
                i,
                a,
                b
            );
        }
    }

    #[test]
    fn test_scalar_basic() {
        let mut data = vec![1.0, 0.0, 0.0, 0.0];
        hadamard_scalar(&mut data);

        for &x in &data {
            assert!((x - 0.5).abs() < 0.01, "x = {}", x);
        }
    }

    #[test]
    fn test_scalar_identity() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut data = original.clone();

        // The normalized transform is its own inverse.
        hadamard_scalar(&mut data);
        hadamard_scalar(&mut data);

        for (a, b) in original.iter().zip(data.iter()) {
            assert!((a - b).abs() < 0.01, "a = {}, b = {}", a, b);
        }
    }

    #[test]
    fn test_kernel_detection() {
        let kernel = HadamardKernel::detect();
        let cap = kernel.capability();

        assert_eq!(cap, simd_capability());

        #[cfg(target_arch = "x86_64")]
        assert!(matches!(
            cap,
            SimdCapability::Scalar
                | SimdCapability::Sse41
                | SimdCapability::Avx2
                | SimdCapability::Avx512
        ));

        #[cfg(target_arch = "aarch64")]
        assert_eq!(cap, SimdCapability::Neon);
    }

    #[test]
    fn test_kernel_consistency() {
        // Sizes 1..=1024 cover the scalar fallbacks, the narrow scalar stages
        // and several full-width vector stages of every kernel.
        for exp in 0..=10u32 {
            let n = 1usize << exp;
            let original = sample(n);

            let mut scalar_data = original.clone();
            hadamard_scalar(&mut scalar_data);

            let mut kernel_data = original.clone();
            hadamard_transform(&mut kernel_data);

            assert_close(&scalar_data, &kernel_data, &format!("n = {}", n));
        }
    }

    #[test]
    fn test_preserves_norm() {
        let mut data: Vec<f32> = (1..=16).map(|i| i as f32).collect();
        let original_norm = squared_norm(&data);

        hadamard_transform(&mut data);

        let new_norm = squared_norm(&data);

        assert!(
            (original_norm - new_norm).abs() < 0.1,
            "Norm changed: {} -> {}",
            original_norm,
            new_norm
        );
    }

    #[test]
    fn test_batch_transform() {
        let dim = 8;
        let num_vectors = 10;
        let original: Vec<f32> = (0..(dim * num_vectors)).map(|i| i as f32 / 100.0).collect();
        let mut flat_data = original.clone();

        hadamard_transform_batch(&mut flat_data, dim);

        for i in 0..num_vectors {
            let start = i * dim;
            let vec = &flat_data[start..start + dim];

            let norm = squared_norm(vec);
            assert!(norm > 0.0, "Vector {} has zero norm", i);

            // Each vector must match an independent scalar transform.
            let mut expected = original[start..start + dim].to_vec();
            hadamard_scalar(&mut expected);
            assert_close(&expected, vec, &format!("vector {}", i));
        }
    }

    #[test]
    fn test_batch_ignores_trailing_partial_vector() {
        let dim = 4;
        let mut flat_data: Vec<f32> = (0..10).map(|i| i as f32).collect();

        hadamard_transform_batch(&mut flat_data, dim);

        assert_eq!(&flat_data[8..], &[8.0_f32, 9.0]);
    }

    #[test]
    fn test_batch_rejects_invalid_dim() {
        let original: Vec<f32> = (0..12).map(|i| i as f32).collect();

        for dim in [0usize, 3, 6] {
            let mut flat_data = original.clone();
            hadamard_transform_batch(&mut flat_data, dim);
            assert_eq!(flat_data, original, "dim = {}", dim);
        }
    }

    #[test]
    fn test_non_power_of_two() {
        let mut data = vec![1.0, 2.0, 3.0];
        let original = data.clone();

        hadamard_transform(&mut data);

        assert_eq!(data, original);
    }

    #[test]
    fn test_empty_input() {
        let mut data: Vec<f32> = Vec::new();

        hadamard_transform(&mut data);

        assert!(data.is_empty());
    }

    #[test]
    fn test_large_dimension() {
        let dim = 1024;
        let mut data: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
        let original_norm = squared_norm(&data);

        hadamard_transform(&mut data);

        let new_norm = squared_norm(&data);

        let rel_error = (original_norm - new_norm).abs() / original_norm;
        assert!(rel_error < 1e-5, "Norm error too large: {}", rel_error);
    }
}