//! Runtime-dispatched SIMD kernels for element-wise addition, multiplication,
//! dot product, and summation of `f32` arrays, with scalar fallbacks.

use ::ndarray::{Array1, ArrayView1};

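/// Element-wise addition of two `f32` arrays.
///
/// Dispatches at runtime to the widest available SIMD instruction set
/// (AVX-512F, AVX2, or SSE on x86_64; NEON on aarch64) and falls back to a
/// scalar loop otherwise.
///
/// # Panics
///
/// Panics if the inputs differ in length or are not contiguous in memory.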
#[inline(always)]
#[allow(clippy::uninit_vec)]
pub fn simd_add_f32_ultra_optimized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    let len = a.len();
    assert_eq!(len, b.len(), "Arrays must have same length");

    let mut result = Vec::with_capacity(len);
    // SAFETY: every element of `result` is written by the kernels below before
    // the vector is read or returned.
    unsafe {
        result.set_len(len);
    }

    #[cfg(target_arch = "x86_64")]
    {
        unsafe {
            let a_ptr = a.as_slice().expect("array must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();

            if is_x86_feature_detected!("avx512f") {
                avx512_add_f32_inner(a_ptr, b_ptr, result_ptr, len);
            } else if is_x86_feature_detected!("avx2") {
                avx2_add_f32_inner(a_ptr, b_ptr, result_ptr, len);
            } else if is_x86_feature_detected!("sse") {
                sse_add_f32_inner(a_ptr, b_ptr, result_ptr, len);
            } else {
                scalar_add_f32_inner(a_ptr, b_ptr, result_ptr, len);
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        unsafe {
            let a_ptr = a.as_slice().expect("array must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();

            if std::arch::is_aarch64_feature_detected!("neon") {
                neon_add_f32_inner(a_ptr, b_ptr, result_ptr, len);
            } else {
                scalar_add_f32_inner(a_ptr, b_ptr, result_ptr, len);
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        unsafe {
            let a_ptr = a.as_slice().expect("array must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();
            scalar_add_f32_inner(a_ptr, b_ptr, result_ptr, len);
        }
    }

    Array1::from_vec(result)
}

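/// AVX-512 add kernel: processes 64 lanes per iteration on the 64-byte-aligned
/// fast path, then 16-lane and scalar tails.
///
/// # Safety
///
/// `a`, `b`, and `result` must be valid for `len` `f32` elements, and the CPU
/// must support AVX-512F (checked by the caller).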
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_add_f32_inner(a: *const f32, b: *const f32, result: *mut f32, len: usize) {
    use std::arch::x86_64::*;

    let mut i = 0;
    const PREFETCH_DISTANCE: usize = 512;

    let a_aligned = (a as usize) % 64 == 0;
    let b_aligned = (b as usize) % 64 == 0;
    let result_aligned = (result as usize) % 64 == 0;

    if a_aligned && b_aligned && result_aligned {
        while i + 64 <= len {
            if i + PREFETCH_DISTANCE < len {
                _mm_prefetch(a.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
                _mm_prefetch(b.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
            }

            let a1 = _mm512_load_ps(a.add(i));
            let b1 = _mm512_load_ps(b.add(i));
            let a2 = _mm512_load_ps(a.add(i + 16));
            let b2 = _mm512_load_ps(b.add(i + 16));
            let a3 = _mm512_load_ps(a.add(i + 32));
            let b3 = _mm512_load_ps(b.add(i + 32));
            let a4 = _mm512_load_ps(a.add(i + 48));
            let b4 = _mm512_load_ps(b.add(i + 48));

            let r1 = _mm512_add_ps(a1, b1);
            let r2 = _mm512_add_ps(a2, b2);
            let r3 = _mm512_add_ps(a3, b3);
            let r4 = _mm512_add_ps(a4, b4);

            _mm512_store_ps(result.add(i), r1);
            _mm512_store_ps(result.add(i + 16), r2);
            _mm512_store_ps(result.add(i + 32), r3);
            _mm512_store_ps(result.add(i + 48), r4);

            i += 64;
        }
    }

    while i + 16 <= len {
        let a_vec = _mm512_loadu_ps(a.add(i));
        let b_vec = _mm512_loadu_ps(b.add(i));
        let result_vec = _mm512_add_ps(a_vec, b_vec);
        _mm512_storeu_ps(result.add(i), result_vec);
        i += 16;
    }

    while i < len {
        *result.add(i) = *a.add(i) + *b.add(i);
        i += 1;
    }
}

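/// AVX2 add kernel: 8-way unrolled 32-byte-aligned fast path, then 8-lane and
/// scalar tails.
///
/// # Safety
///
/// `a`, `b`, and `result` must be valid for `len` `f32` elements, and the CPU
/// must support AVX2 (checked by the caller).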
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx2_add_f32_inner(a: *const f32, b: *const f32, result: *mut f32, len: usize) {
    use std::arch::x86_64::*;

    let mut i = 0;
    const PREFETCH_DISTANCE: usize = 256;

    let a_aligned = (a as usize) % 32 == 0;
    let b_aligned = (b as usize) % 32 == 0;
    let result_aligned = (result as usize) % 32 == 0;

    if a_aligned && b_aligned && result_aligned && len >= 64 {
        while i + 64 <= len {
            if i + PREFETCH_DISTANCE < len {
                _mm_prefetch(a.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
                _mm_prefetch(b.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
            }

            let a1 = _mm256_load_ps(a.add(i));
            let b1 = _mm256_load_ps(b.add(i));
            let a2 = _mm256_load_ps(a.add(i + 8));
            let b2 = _mm256_load_ps(b.add(i + 8));
            let a3 = _mm256_load_ps(a.add(i + 16));
            let b3 = _mm256_load_ps(b.add(i + 16));
            let a4 = _mm256_load_ps(a.add(i + 24));
            let b4 = _mm256_load_ps(b.add(i + 24));
            let a5 = _mm256_load_ps(a.add(i + 32));
            let b5 = _mm256_load_ps(b.add(i + 32));
            let a6 = _mm256_load_ps(a.add(i + 40));
            let b6 = _mm256_load_ps(b.add(i + 40));
            let a7 = _mm256_load_ps(a.add(i + 48));
            let b7 = _mm256_load_ps(b.add(i + 48));
            let a8 = _mm256_load_ps(a.add(i + 56));
            let b8 = _mm256_load_ps(b.add(i + 56));

            let r1 = _mm256_add_ps(a1, b1);
            let r2 = _mm256_add_ps(a2, b2);
            let r3 = _mm256_add_ps(a3, b3);
            let r4 = _mm256_add_ps(a4, b4);
            let r5 = _mm256_add_ps(a5, b5);
            let r6 = _mm256_add_ps(a6, b6);
            let r7 = _mm256_add_ps(a7, b7);
            let r8 = _mm256_add_ps(a8, b8);

            _mm256_store_ps(result.add(i), r1);
            _mm256_store_ps(result.add(i + 8), r2);
            _mm256_store_ps(result.add(i + 16), r3);
            _mm256_store_ps(result.add(i + 24), r4);
            _mm256_store_ps(result.add(i + 32), r5);
            _mm256_store_ps(result.add(i + 40), r6);
            _mm256_store_ps(result.add(i + 48), r7);
            _mm256_store_ps(result.add(i + 56), r8);

            i += 64;
        }
    }

    while i + 8 <= len {
        let a_vec = _mm256_loadu_ps(a.add(i));
        let b_vec = _mm256_loadu_ps(b.add(i));
        let result_vec = _mm256_add_ps(a_vec, b_vec);
        _mm256_storeu_ps(result.add(i), result_vec);
        i += 8;
    }

    while i < len {
        *result.add(i) = *a.add(i) + *b.add(i);
        i += 1;
    }
}

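/// SSE add kernel: unaligned 4-lane loads, 4-way unrolled, with a scalar tail.
///
/// # Safety
///
/// `a`, `b`, and `result` must be valid for `len` `f32` elements (SSE is
/// baseline on x86_64).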
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse")]
unsafe fn sse_add_f32_inner(a: *const f32, b: *const f32, result: *mut f32, len: usize) {
    use std::arch::x86_64::*;

    let mut i = 0;

    while i + 16 <= len {
        let a1 = _mm_loadu_ps(a.add(i));
        let b1 = _mm_loadu_ps(b.add(i));
        let a2 = _mm_loadu_ps(a.add(i + 4));
        let b2 = _mm_loadu_ps(b.add(i + 4));
        let a3 = _mm_loadu_ps(a.add(i + 8));
        let b3 = _mm_loadu_ps(b.add(i + 8));
        let a4 = _mm_loadu_ps(a.add(i + 12));
        let b4 = _mm_loadu_ps(b.add(i + 12));

        let r1 = _mm_add_ps(a1, b1);
        let r2 = _mm_add_ps(a2, b2);
        let r3 = _mm_add_ps(a3, b3);
        let r4 = _mm_add_ps(a4, b4);

        _mm_storeu_ps(result.add(i), r1);
        _mm_storeu_ps(result.add(i + 4), r2);
        _mm_storeu_ps(result.add(i + 8), r3);
        _mm_storeu_ps(result.add(i + 12), r4);

        i += 16;
    }

    while i + 4 <= len {
        let a_vec = _mm_loadu_ps(a.add(i));
        let b_vec = _mm_loadu_ps(b.add(i));
        let result_vec = _mm_add_ps(a_vec, b_vec);
        _mm_storeu_ps(result.add(i), result_vec);
        i += 4;
    }

    while i < len {
        *result.add(i) = *a.add(i) + *b.add(i);
        i += 1;
    }
}

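/// NEON add kernel: 4-way unrolled 4-lane loads with a scalar tail.
///
/// # Safety
///
/// `a`, `b`, and `result` must be valid for `len` `f32` elements.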
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn neon_add_f32_inner(a: *const f32, b: *const f32, result: *mut f32, len: usize) {
    use std::arch::aarch64::*;

    let mut i = 0;

    while i + 16 <= len {
        let a1 = vld1q_f32(a.add(i));
        let b1 = vld1q_f32(b.add(i));
        let a2 = vld1q_f32(a.add(i + 4));
        let b2 = vld1q_f32(b.add(i + 4));
        let a3 = vld1q_f32(a.add(i + 8));
        let b3 = vld1q_f32(b.add(i + 8));
        let a4 = vld1q_f32(a.add(i + 12));
        let b4 = vld1q_f32(b.add(i + 12));

        let r1 = vaddq_f32(a1, b1);
        let r2 = vaddq_f32(a2, b2);
        let r3 = vaddq_f32(a3, b3);
        let r4 = vaddq_f32(a4, b4);

        vst1q_f32(result.add(i), r1);
        vst1q_f32(result.add(i + 4), r2);
        vst1q_f32(result.add(i + 8), r3);
        vst1q_f32(result.add(i + 12), r4);

        i += 16;
    }

    while i + 4 <= len {
        let a_vec = vld1q_f32(a.add(i));
        let b_vec = vld1q_f32(b.add(i));
        let result_vec = vaddq_f32(a_vec, b_vec);
        vst1q_f32(result.add(i), result_vec);
        i += 4;
    }

    while i < len {
        *result.add(i) = *a.add(i) + *b.add(i);
        i += 1;
    }
}

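/// Scalar add fallback used when no SIMD feature is available.
///
/// # Safety
///
/// `a`, `b`, and `result` must be valid for `len` `f32` elements.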
#[inline(always)]
unsafe fn scalar_add_f32_inner(a: *const f32, b: *const f32, result: *mut f32, len: usize) {
    for i in 0..len {
        *result.add(i) = *a.add(i) + *b.add(i);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::Array1;

    #[test]
    fn test_ultra_optimized_add() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let b = Array1::from_vec(vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]);

        let result = simd_add_f32_ultra_optimized(&a.view(), &b.view());

        for i in 0..8 {
            assert_eq!(result[i], 9.0);
        }
    }

    #[test]
    fn test_large_array() {
        let size = 10000;
        let a = Array1::from_elem(size, 2.0f32);
        let b = Array1::from_elem(size, 3.0f32);

        let result = simd_add_f32_ultra_optimized(&a.view(), &b.view());

        for i in 0..size {
            assert_eq!(result[i], 5.0);
        }
    }
}

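/// Element-wise multiplication of two `f32` arrays.
///
/// Dispatches at runtime to the widest available SIMD instruction set
/// (AVX-512F, AVX2, or SSE on x86_64; NEON on aarch64) and falls back to a
/// scalar loop otherwise.
///
/// # Panics
///
/// Panics if the inputs differ in length or are not contiguous in memory.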
#[inline(always)]
#[allow(clippy::uninit_vec)]
pub fn simd_mul_f32_ultra_optimized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    let len = a.len();
    assert_eq!(len, b.len(), "Arrays must have same length");

    let mut result = Vec::with_capacity(len);
    // SAFETY: every element of `result` is written by the kernels below before
    // the vector is read or returned.
    unsafe {
        result.set_len(len);
    }

    #[cfg(target_arch = "x86_64")]
    {
        unsafe {
            let a_ptr = a.as_slice().expect("array must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();

            if is_x86_feature_detected!("avx512f") {
                avx512_mul_f32_inner(a_ptr, b_ptr, result_ptr, len);
            } else if is_x86_feature_detected!("avx2") {
                avx2_mul_f32_inner(a_ptr, b_ptr, result_ptr, len);
            } else if is_x86_feature_detected!("sse") {
                sse_mul_f32_inner(a_ptr, b_ptr, result_ptr, len);
            } else {
                scalar_mul_f32_inner(a_ptr, b_ptr, result_ptr, len);
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        unsafe {
            let a_ptr = a.as_slice().expect("array must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();

            if std::arch::is_aarch64_feature_detected!("neon") {
                neon_mul_f32_inner(a_ptr, b_ptr, result_ptr, len);
            } else {
                scalar_mul_f32_inner(a_ptr, b_ptr, result_ptr, len);
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        unsafe {
            let a_ptr = a.as_slice().expect("array must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();
            scalar_mul_f32_inner(a_ptr, b_ptr, result_ptr, len);
        }
    }

    Array1::from_vec(result)
}

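/// AVX-512 multiply kernel: 64 lanes per iteration on the 64-byte-aligned fast
/// path, then 16-lane and scalar tails.
///
/// # Safety
///
/// `a`, `b`, and `result` must be valid for `len` `f32` elements, and the CPU
/// must support AVX-512F (checked by the caller).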
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_mul_f32_inner(a: *const f32, b: *const f32, result: *mut f32, len: usize) {
    use std::arch::x86_64::*;

    let mut i = 0;
    const PREFETCH_DISTANCE: usize = 512;

    let a_aligned = (a as usize) % 64 == 0;
    let b_aligned = (b as usize) % 64 == 0;
    let result_aligned = (result as usize) % 64 == 0;

    if a_aligned && b_aligned && result_aligned {
        while i + 64 <= len {
            if i + PREFETCH_DISTANCE < len {
                _mm_prefetch(a.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
                _mm_prefetch(b.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
            }

            let a1 = _mm512_load_ps(a.add(i));
            let b1 = _mm512_load_ps(b.add(i));
            let a2 = _mm512_load_ps(a.add(i + 16));
            let b2 = _mm512_load_ps(b.add(i + 16));
            let a3 = _mm512_load_ps(a.add(i + 32));
            let b3 = _mm512_load_ps(b.add(i + 32));
            let a4 = _mm512_load_ps(a.add(i + 48));
            let b4 = _mm512_load_ps(b.add(i + 48));

            let r1 = _mm512_mul_ps(a1, b1);
            let r2 = _mm512_mul_ps(a2, b2);
            let r3 = _mm512_mul_ps(a3, b3);
            let r4 = _mm512_mul_ps(a4, b4);

            _mm512_store_ps(result.add(i), r1);
            _mm512_store_ps(result.add(i + 16), r2);
            _mm512_store_ps(result.add(i + 32), r3);
            _mm512_store_ps(result.add(i + 48), r4);

            i += 64;
        }
    }

    while i + 16 <= len {
        let a_vec = _mm512_loadu_ps(a.add(i));
        let b_vec = _mm512_loadu_ps(b.add(i));
        let result_vec = _mm512_mul_ps(a_vec, b_vec);
        _mm512_storeu_ps(result.add(i), result_vec);
        i += 16;
    }

    while i < len {
        *result.add(i) = *a.add(i) * *b.add(i);
        i += 1;
    }
}

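/// AVX2 multiply kernel: 8-way unrolled 32-byte-aligned fast path, then 8-lane
/// and scalar tails.
///
/// # Safety
///
/// `a`, `b`, and `result` must be valid for `len` `f32` elements, and the CPU
/// must support AVX2 (checked by the caller).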
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx2_mul_f32_inner(a: *const f32, b: *const f32, result: *mut f32, len: usize) {
    use std::arch::x86_64::*;

    let mut i = 0;
    const PREFETCH_DISTANCE: usize = 256;

    let a_aligned = (a as usize) % 32 == 0;
    let b_aligned = (b as usize) % 32 == 0;
    let result_aligned = (result as usize) % 32 == 0;

    if a_aligned && b_aligned && result_aligned && len >= 64 {
        while i + 64 <= len {
            if i + PREFETCH_DISTANCE < len {
                _mm_prefetch(a.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
                _mm_prefetch(b.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
            }

            let a1 = _mm256_load_ps(a.add(i));
            let b1 = _mm256_load_ps(b.add(i));
            let a2 = _mm256_load_ps(a.add(i + 8));
            let b2 = _mm256_load_ps(b.add(i + 8));
            let a3 = _mm256_load_ps(a.add(i + 16));
            let b3 = _mm256_load_ps(b.add(i + 16));
            let a4 = _mm256_load_ps(a.add(i + 24));
            let b4 = _mm256_load_ps(b.add(i + 24));
            let a5 = _mm256_load_ps(a.add(i + 32));
            let b5 = _mm256_load_ps(b.add(i + 32));
            let a6 = _mm256_load_ps(a.add(i + 40));
            let b6 = _mm256_load_ps(b.add(i + 40));
            let a7 = _mm256_load_ps(a.add(i + 48));
            let b7 = _mm256_load_ps(b.add(i + 48));
            let a8 = _mm256_load_ps(a.add(i + 56));
            let b8 = _mm256_load_ps(b.add(i + 56));

            let r1 = _mm256_mul_ps(a1, b1);
            let r2 = _mm256_mul_ps(a2, b2);
            let r3 = _mm256_mul_ps(a3, b3);
            let r4 = _mm256_mul_ps(a4, b4);
            let r5 = _mm256_mul_ps(a5, b5);
            let r6 = _mm256_mul_ps(a6, b6);
            let r7 = _mm256_mul_ps(a7, b7);
            let r8 = _mm256_mul_ps(a8, b8);

            _mm256_store_ps(result.add(i), r1);
            _mm256_store_ps(result.add(i + 8), r2);
            _mm256_store_ps(result.add(i + 16), r3);
            _mm256_store_ps(result.add(i + 24), r4);
            _mm256_store_ps(result.add(i + 32), r5);
            _mm256_store_ps(result.add(i + 40), r6);
            _mm256_store_ps(result.add(i + 48), r7);
            _mm256_store_ps(result.add(i + 56), r8);

            i += 64;
        }
    }

    while i + 8 <= len {
        let a_vec = _mm256_loadu_ps(a.add(i));
        let b_vec = _mm256_loadu_ps(b.add(i));
        let result_vec = _mm256_mul_ps(a_vec, b_vec);
        _mm256_storeu_ps(result.add(i), result_vec);
        i += 8;
    }

    while i < len {
        *result.add(i) = *a.add(i) * *b.add(i);
        i += 1;
    }
}

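/// SSE multiply kernel: unaligned 4-lane loads, 4-way unrolled, with a scalar
/// tail.
///
/// # Safety
///
/// `a`, `b`, and `result` must be valid for `len` `f32` elements (SSE is
/// baseline on x86_64).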
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse")]
unsafe fn sse_mul_f32_inner(a: *const f32, b: *const f32, result: *mut f32, len: usize) {
    use std::arch::x86_64::*;

    let mut i = 0;

    while i + 16 <= len {
        let a1 = _mm_loadu_ps(a.add(i));
        let b1 = _mm_loadu_ps(b.add(i));
        let a2 = _mm_loadu_ps(a.add(i + 4));
        let b2 = _mm_loadu_ps(b.add(i + 4));
        let a3 = _mm_loadu_ps(a.add(i + 8));
        let b3 = _mm_loadu_ps(b.add(i + 8));
        let a4 = _mm_loadu_ps(a.add(i + 12));
        let b4 = _mm_loadu_ps(b.add(i + 12));

        let r1 = _mm_mul_ps(a1, b1);
        let r2 = _mm_mul_ps(a2, b2);
        let r3 = _mm_mul_ps(a3, b3);
        let r4 = _mm_mul_ps(a4, b4);

        _mm_storeu_ps(result.add(i), r1);
        _mm_storeu_ps(result.add(i + 4), r2);
        _mm_storeu_ps(result.add(i + 8), r3);
        _mm_storeu_ps(result.add(i + 12), r4);

        i += 16;
    }

    while i + 4 <= len {
        let a_vec = _mm_loadu_ps(a.add(i));
        let b_vec = _mm_loadu_ps(b.add(i));
        let result_vec = _mm_mul_ps(a_vec, b_vec);
        _mm_storeu_ps(result.add(i), result_vec);
        i += 4;
    }

    while i < len {
        *result.add(i) = *a.add(i) * *b.add(i);
        i += 1;
    }
}

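/// NEON multiply kernel: 4-way unrolled 4-lane loads with a scalar tail.
///
/// # Safety
///
/// `a`, `b`, and `result` must be valid for `len` `f32` elements.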
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn neon_mul_f32_inner(a: *const f32, b: *const f32, result: *mut f32, len: usize) {
    use std::arch::aarch64::*;

    let mut i = 0;

    while i + 16 <= len {
        let a1 = vld1q_f32(a.add(i));
        let b1 = vld1q_f32(b.add(i));
        let a2 = vld1q_f32(a.add(i + 4));
        let b2 = vld1q_f32(b.add(i + 4));
        let a3 = vld1q_f32(a.add(i + 8));
        let b3 = vld1q_f32(b.add(i + 8));
        let a4 = vld1q_f32(a.add(i + 12));
        let b4 = vld1q_f32(b.add(i + 12));

        let r1 = vmulq_f32(a1, b1);
        let r2 = vmulq_f32(a2, b2);
        let r3 = vmulq_f32(a3, b3);
        let r4 = vmulq_f32(a4, b4);

        vst1q_f32(result.add(i), r1);
        vst1q_f32(result.add(i + 4), r2);
        vst1q_f32(result.add(i + 8), r3);
        vst1q_f32(result.add(i + 12), r4);

        i += 16;
    }

    while i + 4 <= len {
        let a_vec = vld1q_f32(a.add(i));
        let b_vec = vld1q_f32(b.add(i));
        let result_vec = vmulq_f32(a_vec, b_vec);
        vst1q_f32(result.add(i), result_vec);
        i += 4;
    }

    while i < len {
        *result.add(i) = *a.add(i) * *b.add(i);
        i += 1;
    }
}

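/// Scalar multiply fallback used when no SIMD feature is available.
///
/// # Safety
///
/// `a`, `b`, and `result` must be valid for `len` `f32` elements.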
#[inline(always)]
unsafe fn scalar_mul_f32_inner(a: *const f32, b: *const f32, result: *mut f32, len: usize) {
    for i in 0..len {
        *result.add(i) = *a.add(i) * *b.add(i);
    }
}

#[cfg(test)]
mod mul_tests {
    use super::*;
    use ::ndarray::Array1;

    #[test]
    fn test_ultra_optimized_mul() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let b = Array1::from_vec(vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);

        let result = simd_mul_f32_ultra_optimized(&a.view(), &b.view());

        assert_eq!(result[0], 2.0);
        assert_eq!(result[1], 6.0);
        assert_eq!(result[2], 12.0);
        assert_eq!(result[7], 72.0);
    }
}

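/// Dot product of two `f32` arrays using runtime-dispatched SIMD
/// (AVX-512F, AVX2+FMA, or SSE2 on x86_64; NEON on aarch64) with a scalar
/// fallback.
///
/// # Panics
///
/// Panics if the inputs differ in length or are not contiguous in memory.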
#[inline(always)]
pub fn simd_dot_f32_ultra_optimized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
    let len = a.len();
    assert_eq!(len, b.len(), "Arrays must have same length");

    #[cfg(target_arch = "x86_64")]
    {
        unsafe {
            let a_ptr = a.as_slice().expect("array must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array must be contiguous").as_ptr();

            if is_x86_feature_detected!("avx512f") {
                return avx512_dot_f32_inner(a_ptr, b_ptr, len);
            } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
                // The AVX2 kernel relies on FMA intrinsics, so both features are required.
                return avx2_dot_f32_inner(a_ptr, b_ptr, len);
            } else if is_x86_feature_detected!("sse2") {
                return sse_dot_f32_inner(a_ptr, b_ptr, len);
            } else {
                return scalar_dot_f32(a, b);
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    unsafe {
        let a_ptr = a.as_slice().expect("array must be contiguous").as_ptr();
        let b_ptr = b.as_slice().expect("array must be contiguous").as_ptr();
        return neon_dot_f32_inner(a_ptr, b_ptr, len);
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        scalar_dot_f32(a, b)
    }
}

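/// AVX-512 dot-product kernel using four FMA accumulators over 64-lane chunks,
/// with 16-lane and scalar tails.
///
/// # Safety
///
/// `a` and `b` must be valid for `len` `f32` elements, and the CPU must
/// support AVX-512F (checked by the caller).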
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_dot_f32_inner(a: *const f32, b: *const f32, len: usize) -> f32 {
    use std::arch::x86_64::*;

    const PREFETCH_DISTANCE: usize = 512;
    const VECTOR_SIZE: usize = 16;
    const UNROLL_FACTOR: usize = 4;
    const CHUNK_SIZE: usize = VECTOR_SIZE * UNROLL_FACTOR;

    let mut i = 0;

    let mut acc1 = _mm512_setzero_ps();
    let mut acc2 = _mm512_setzero_ps();
    let mut acc3 = _mm512_setzero_ps();
    let mut acc4 = _mm512_setzero_ps();

    let a_aligned = (a as usize) % 64 == 0;
    let b_aligned = (b as usize) % 64 == 0;

    if a_aligned && b_aligned && len >= CHUNK_SIZE {
        while i + CHUNK_SIZE <= len {
            if i + PREFETCH_DISTANCE < len {
                _mm_prefetch(a.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
                _mm_prefetch(b.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
            }

            let a1 = _mm512_load_ps(a.add(i));
            let a2 = _mm512_load_ps(a.add(i + 16));
            let a3 = _mm512_load_ps(a.add(i + 32));
            let a4 = _mm512_load_ps(a.add(i + 48));

            let b1 = _mm512_load_ps(b.add(i));
            let b2 = _mm512_load_ps(b.add(i + 16));
            let b3 = _mm512_load_ps(b.add(i + 32));
            let b4 = _mm512_load_ps(b.add(i + 48));

            acc1 = _mm512_fmadd_ps(a1, b1, acc1);
            acc2 = _mm512_fmadd_ps(a2, b2, acc2);
            acc3 = _mm512_fmadd_ps(a3, b3, acc3);
            acc4 = _mm512_fmadd_ps(a4, b4, acc4);

            i += CHUNK_SIZE;
        }
    }

    while i + VECTOR_SIZE <= len {
        let a_vec = _mm512_loadu_ps(a.add(i));
        let b_vec = _mm512_loadu_ps(b.add(i));
        acc1 = _mm512_fmadd_ps(a_vec, b_vec, acc1);
        i += VECTOR_SIZE;
    }

    let combined1 = _mm512_add_ps(acc1, acc2);
    let combined2 = _mm512_add_ps(acc3, acc4);
    let final_acc = _mm512_add_ps(combined1, combined2);

    let mut result = _mm512_reduce_add_ps(final_acc);

    while i < len {
        result += *a.add(i) * *b.add(i);
        i += 1;
    }

    result
}

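/// AVX2/FMA dot-product kernel using eight accumulators over 64-lane chunks,
/// with 8-lane and scalar tails.
///
/// # Safety
///
/// `a` and `b` must be valid for `len` `f32` elements, and the CPU must
/// support AVX2 and FMA (checked by the caller).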
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn avx2_dot_f32_inner(a: *const f32, b: *const f32, len: usize) -> f32 {
    use std::arch::x86_64::*;

    const PREFETCH_DISTANCE: usize = 256;
    const VECTOR_SIZE: usize = 8;
    const UNROLL_FACTOR: usize = 8;
    const CHUNK_SIZE: usize = VECTOR_SIZE * UNROLL_FACTOR;

    let mut i = 0;

    let mut acc1 = _mm256_setzero_ps();
    let mut acc2 = _mm256_setzero_ps();
    let mut acc3 = _mm256_setzero_ps();
    let mut acc4 = _mm256_setzero_ps();
    let mut acc5 = _mm256_setzero_ps();
    let mut acc6 = _mm256_setzero_ps();
    let mut acc7 = _mm256_setzero_ps();
    let mut acc8 = _mm256_setzero_ps();

    let a_aligned = (a as usize) % 32 == 0;
    let b_aligned = (b as usize) % 32 == 0;

    if a_aligned && b_aligned && len >= CHUNK_SIZE {
        while i + CHUNK_SIZE <= len {
            if i + PREFETCH_DISTANCE < len {
                _mm_prefetch(a.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
                _mm_prefetch(b.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
            }

            let a1 = _mm256_load_ps(a.add(i));
            let a2 = _mm256_load_ps(a.add(i + 8));
            let a3 = _mm256_load_ps(a.add(i + 16));
            let a4 = _mm256_load_ps(a.add(i + 24));
            let a5 = _mm256_load_ps(a.add(i + 32));
            let a6 = _mm256_load_ps(a.add(i + 40));
            let a7 = _mm256_load_ps(a.add(i + 48));
            let a8 = _mm256_load_ps(a.add(i + 56));

            let b1 = _mm256_load_ps(b.add(i));
            let b2 = _mm256_load_ps(b.add(i + 8));
            let b3 = _mm256_load_ps(b.add(i + 16));
            let b4 = _mm256_load_ps(b.add(i + 24));
            let b5 = _mm256_load_ps(b.add(i + 32));
            let b6 = _mm256_load_ps(b.add(i + 40));
            let b7 = _mm256_load_ps(b.add(i + 48));
            let b8 = _mm256_load_ps(b.add(i + 56));

            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
            acc2 = _mm256_fmadd_ps(a2, b2, acc2);
            acc3 = _mm256_fmadd_ps(a3, b3, acc3);
            acc4 = _mm256_fmadd_ps(a4, b4, acc4);
            acc5 = _mm256_fmadd_ps(a5, b5, acc5);
            acc6 = _mm256_fmadd_ps(a6, b6, acc6);
            acc7 = _mm256_fmadd_ps(a7, b7, acc7);
            acc8 = _mm256_fmadd_ps(a8, b8, acc8);

            i += CHUNK_SIZE;
        }
    }

    while i + VECTOR_SIZE <= len {
        let a_vec = _mm256_loadu_ps(a.add(i));
        let b_vec = _mm256_loadu_ps(b.add(i));
        acc1 = _mm256_fmadd_ps(a_vec, b_vec, acc1);
        i += VECTOR_SIZE;
    }

    let combined1 = _mm256_add_ps(acc1, acc2);
    let combined2 = _mm256_add_ps(acc3, acc4);
    let combined3 = _mm256_add_ps(acc5, acc6);
    let combined4 = _mm256_add_ps(acc7, acc8);

    let combined12 = _mm256_add_ps(combined1, combined2);
    let combined34 = _mm256_add_ps(combined3, combined4);
    let final_acc = _mm256_add_ps(combined12, combined34);

    let high = _mm256_extractf128_ps(final_acc, 1);
    let low = _mm256_castps256_ps128(final_acc);
    let sum128 = _mm_add_ps(low, high);

    let shuf = _mm_shuffle_ps(sum128, sum128, 0b1110);
    let sum_partial = _mm_add_ps(sum128, shuf);
    let shuf2 = _mm_shuffle_ps(sum_partial, sum_partial, 0b0001);
    let final_result = _mm_add_ps(sum_partial, shuf2);

    let mut result = _mm_cvtss_f32(final_result);

    while i < len {
        result += *a.add(i) * *b.add(i);
        i += 1;
    }

    result
}

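/// SSE2 dot-product kernel using four accumulators (multiply plus add, no FMA),
/// with 4-lane and scalar tails.
///
/// # Safety
///
/// `a` and `b` must be valid for `len` `f32` elements (SSE2 is baseline on
/// x86_64).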
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn sse_dot_f32_inner(a: *const f32, b: *const f32, len: usize) -> f32 {
    use std::arch::x86_64::*;

    const VECTOR_SIZE: usize = 4;
    const UNROLL_FACTOR: usize = 4;
    const CHUNK_SIZE: usize = VECTOR_SIZE * UNROLL_FACTOR;

    let mut i = 0;

    let mut acc1 = _mm_setzero_ps();
    let mut acc2 = _mm_setzero_ps();
    let mut acc3 = _mm_setzero_ps();
    let mut acc4 = _mm_setzero_ps();

    while i + CHUNK_SIZE <= len {
        let a1 = _mm_loadu_ps(a.add(i));
        let a2 = _mm_loadu_ps(a.add(i + 4));
        let a3 = _mm_loadu_ps(a.add(i + 8));
        let a4 = _mm_loadu_ps(a.add(i + 12));

        let b1 = _mm_loadu_ps(b.add(i));
        let b2 = _mm_loadu_ps(b.add(i + 4));
        let b3 = _mm_loadu_ps(b.add(i + 8));
        let b4 = _mm_loadu_ps(b.add(i + 12));

        let prod1 = _mm_mul_ps(a1, b1);
        let prod2 = _mm_mul_ps(a2, b2);
        let prod3 = _mm_mul_ps(a3, b3);
        let prod4 = _mm_mul_ps(a4, b4);

        acc1 = _mm_add_ps(acc1, prod1);
        acc2 = _mm_add_ps(acc2, prod2);
        acc3 = _mm_add_ps(acc3, prod3);
        acc4 = _mm_add_ps(acc4, prod4);

        i += CHUNK_SIZE;
    }

    while i + VECTOR_SIZE <= len {
        let a_vec = _mm_loadu_ps(a.add(i));
        let b_vec = _mm_loadu_ps(b.add(i));
        let prod = _mm_mul_ps(a_vec, b_vec);
        acc1 = _mm_add_ps(acc1, prod);
        i += VECTOR_SIZE;
    }

    let combined1 = _mm_add_ps(acc1, acc2);
    let combined2 = _mm_add_ps(acc3, acc4);
    let final_acc = _mm_add_ps(combined1, combined2);

    let shuf = _mm_shuffle_ps(final_acc, final_acc, 0b1110);
    let sum_partial = _mm_add_ps(final_acc, shuf);
    let shuf2 = _mm_shuffle_ps(sum_partial, sum_partial, 0b0001);
    let final_result = _mm_add_ps(sum_partial, shuf2);

    let mut result = _mm_cvtss_f32(final_result);

    while i < len {
        result += *a.add(i) * *b.add(i);
        i += 1;
    }

    result
}

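/// NEON dot-product kernel using four `vfmaq_f32` accumulators, with 4-lane
/// and scalar tails.
///
/// # Safety
///
/// `a` and `b` must be valid for `len` `f32` elements.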
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn neon_dot_f32_inner(a: *const f32, b: *const f32, len: usize) -> f32 {
    use std::arch::aarch64::*;

    const VECTOR_SIZE: usize = 4;
    const UNROLL_FACTOR: usize = 4;
    const CHUNK_SIZE: usize = VECTOR_SIZE * UNROLL_FACTOR;

    let mut i = 0;

    let mut acc1 = vdupq_n_f32(0.0);
    let mut acc2 = vdupq_n_f32(0.0);
    let mut acc3 = vdupq_n_f32(0.0);
    let mut acc4 = vdupq_n_f32(0.0);

    while i + CHUNK_SIZE <= len {
        let a1 = vld1q_f32(a.add(i));
        let a2 = vld1q_f32(a.add(i + 4));
        let a3 = vld1q_f32(a.add(i + 8));
        let a4 = vld1q_f32(a.add(i + 12));

        let b1 = vld1q_f32(b.add(i));
        let b2 = vld1q_f32(b.add(i + 4));
        let b3 = vld1q_f32(b.add(i + 8));
        let b4 = vld1q_f32(b.add(i + 12));

        acc1 = vfmaq_f32(acc1, a1, b1);
        acc2 = vfmaq_f32(acc2, a2, b2);
        acc3 = vfmaq_f32(acc3, a3, b3);
        acc4 = vfmaq_f32(acc4, a4, b4);

        i += CHUNK_SIZE;
    }

    while i + VECTOR_SIZE <= len {
        let a_vec = vld1q_f32(a.add(i));
        let b_vec = vld1q_f32(b.add(i));
        acc1 = vfmaq_f32(acc1, a_vec, b_vec);
        i += VECTOR_SIZE;
    }

    let combined1 = vaddq_f32(acc1, acc2);
    let combined2 = vaddq_f32(acc3, acc4);
    let final_acc = vaddq_f32(combined1, combined2);

    let mut result = vaddvq_f32(final_acc);

    while i < len {
        result += *a.add(i) * *b.add(i);
        i += 1;
    }

    result
}

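/// Scalar dot-product fallback; panics if either view is not contiguous.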
#[inline(always)]
fn scalar_dot_f32(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
    let a_slice = a.as_slice().expect("array must be contiguous");
    let b_slice = b.as_slice().expect("array must be contiguous");

    a_slice.iter().zip(b_slice.iter()).map(|(x, y)| x * y).sum()
}

#[cfg(test)]
mod dot_tests {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn test_dot_product_ultra_optimized() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let b = Array1::from_vec(vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]);

        let result = simd_dot_f32_ultra_optimized(&a.view(), &b.view());

        assert_eq!(result, 120.0);
    }

    #[test]
    fn test_dot_product_large_array() {
        let size = 10000;
        let a = Array1::from_elem(size, 2.0f32);
        let b = Array1::from_elem(size, 3.0f32);

        let result = simd_dot_f32_ultra_optimized(&a.view(), &b.view());

        assert!((result - 60000.0).abs() < 0.001);
    }
}

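/// Horizontal sum of an `f32` array using runtime-dispatched SIMD
/// (AVX-512F, AVX2, or SSE2 on x86_64; NEON on aarch64) with a scalar
/// fallback. Returns `0.0` for an empty input.
///
/// # Panics
///
/// Panics if the input is not contiguous in memory.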
#[inline(always)]
pub fn simd_sum_f32_ultra_optimized(input: &ArrayView1<f32>) -> f32 {
    let len = input.len();
    if len == 0 {
        return 0.0;
    }

    #[cfg(target_arch = "x86_64")]
    {
        unsafe {
            let ptr = input.as_slice().expect("array must be contiguous").as_ptr();

            if is_x86_feature_detected!("avx512f") {
                return avx512_sum_f32_inner(ptr, len);
            } else if is_x86_feature_detected!("avx2") {
                return avx2_sum_f32_inner(ptr, len);
            } else if is_x86_feature_detected!("sse2") {
                return sse_sum_f32_inner(ptr, len);
            } else {
                return scalar_sum_f32(input);
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    unsafe {
        let ptr = input.as_slice().expect("array must be contiguous").as_ptr();
        return neon_sum_f32_inner(ptr, len);
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        scalar_sum_f32(input)
    }
}

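/// AVX-512 summation kernel using four accumulators over 64-lane aligned
/// chunks, with 16-lane and scalar tails.
///
/// # Safety
///
/// `ptr` must be valid for `len` `f32` elements, and the CPU must support
/// AVX-512F (checked by the caller).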
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_sum_f32_inner(ptr: *const f32, len: usize) -> f32 {
    use std::arch::x86_64::*;

    const PREFETCH_DISTANCE: usize = 512;
    const VECTOR_SIZE: usize = 16;
    const UNROLL_FACTOR: usize = 4;
    const CHUNK_SIZE: usize = VECTOR_SIZE * UNROLL_FACTOR;

    let mut i = 0;

    let mut acc1 = _mm512_setzero_ps();
    let mut acc2 = _mm512_setzero_ps();
    let mut acc3 = _mm512_setzero_ps();
    let mut acc4 = _mm512_setzero_ps();

    let aligned = (ptr as usize) % 64 == 0;

    if aligned && len >= CHUNK_SIZE {
        while i + CHUNK_SIZE <= len {
            if i + PREFETCH_DISTANCE < len {
                _mm_prefetch(ptr.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
            }

            let v1 = _mm512_load_ps(ptr.add(i));
            let v2 = _mm512_load_ps(ptr.add(i + 16));
            let v3 = _mm512_load_ps(ptr.add(i + 32));
            let v4 = _mm512_load_ps(ptr.add(i + 48));

            acc1 = _mm512_add_ps(acc1, v1);
            acc2 = _mm512_add_ps(acc2, v2);
            acc3 = _mm512_add_ps(acc3, v3);
            acc4 = _mm512_add_ps(acc4, v4);

            i += CHUNK_SIZE;
        }
    }

    while i + VECTOR_SIZE <= len {
        let v = _mm512_loadu_ps(ptr.add(i));
        acc1 = _mm512_add_ps(acc1, v);
        i += VECTOR_SIZE;
    }

    let combined1 = _mm512_add_ps(acc1, acc2);
    let combined2 = _mm512_add_ps(acc3, acc4);
    let final_acc = _mm512_add_ps(combined1, combined2);

    let mut result = _mm512_reduce_add_ps(final_acc);

    while i < len {
        result += *ptr.add(i);
        i += 1;
    }

    result
}

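/// AVX2 summation kernel using eight accumulators over 64-lane aligned chunks,
/// with 8-lane and scalar tails.
///
/// # Safety
///
/// `ptr` must be valid for `len` `f32` elements, and the CPU must support
/// AVX2 (checked by the caller).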
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sum_f32_inner(ptr: *const f32, len: usize) -> f32 {
    use std::arch::x86_64::*;

    const PREFETCH_DISTANCE: usize = 256;
    const VECTOR_SIZE: usize = 8;
    const UNROLL_FACTOR: usize = 8;
    const CHUNK_SIZE: usize = VECTOR_SIZE * UNROLL_FACTOR;

    let mut i = 0;

    let mut acc1 = _mm256_setzero_ps();
    let mut acc2 = _mm256_setzero_ps();
    let mut acc3 = _mm256_setzero_ps();
    let mut acc4 = _mm256_setzero_ps();
    let mut acc5 = _mm256_setzero_ps();
    let mut acc6 = _mm256_setzero_ps();
    let mut acc7 = _mm256_setzero_ps();
    let mut acc8 = _mm256_setzero_ps();

    let aligned = (ptr as usize) % 32 == 0;

    if aligned && len >= CHUNK_SIZE {
        while i + CHUNK_SIZE <= len {
            if i + PREFETCH_DISTANCE < len {
                _mm_prefetch(ptr.add(i + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
            }

            let v1 = _mm256_load_ps(ptr.add(i));
            let v2 = _mm256_load_ps(ptr.add(i + 8));
            let v3 = _mm256_load_ps(ptr.add(i + 16));
            let v4 = _mm256_load_ps(ptr.add(i + 24));
            let v5 = _mm256_load_ps(ptr.add(i + 32));
            let v6 = _mm256_load_ps(ptr.add(i + 40));
            let v7 = _mm256_load_ps(ptr.add(i + 48));
            let v8 = _mm256_load_ps(ptr.add(i + 56));

            acc1 = _mm256_add_ps(acc1, v1);
            acc2 = _mm256_add_ps(acc2, v2);
            acc3 = _mm256_add_ps(acc3, v3);
            acc4 = _mm256_add_ps(acc4, v4);
            acc5 = _mm256_add_ps(acc5, v5);
            acc6 = _mm256_add_ps(acc6, v6);
            acc7 = _mm256_add_ps(acc7, v7);
            acc8 = _mm256_add_ps(acc8, v8);

            i += CHUNK_SIZE;
        }
    }

    while i + VECTOR_SIZE <= len {
        let v = _mm256_loadu_ps(ptr.add(i));
        acc1 = _mm256_add_ps(acc1, v);
        i += VECTOR_SIZE;
    }

    let combined1 = _mm256_add_ps(acc1, acc2);
    let combined2 = _mm256_add_ps(acc3, acc4);
    let combined3 = _mm256_add_ps(acc5, acc6);
    let combined4 = _mm256_add_ps(acc7, acc8);

    let combined12 = _mm256_add_ps(combined1, combined2);
    let combined34 = _mm256_add_ps(combined3, combined4);
    let final_acc = _mm256_add_ps(combined12, combined34);

    let high = _mm256_extractf128_ps(final_acc, 1);
    let low = _mm256_castps256_ps128(final_acc);
    let sum128 = _mm_add_ps(low, high);

    let shuf = _mm_shuffle_ps(sum128, sum128, 0b1110);
    let sum_partial = _mm_add_ps(sum128, shuf);
    let shuf2 = _mm_shuffle_ps(sum_partial, sum_partial, 0b0001);
    let final_result = _mm_add_ps(sum_partial, shuf2);

    let mut result = _mm_cvtss_f32(final_result);

    while i < len {
        result += *ptr.add(i);
        i += 1;
    }

    result
}

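/// SSE2 summation kernel using four accumulators, with 4-lane and scalar tails.
///
/// # Safety
///
/// `ptr` must be valid for `len` `f32` elements (SSE2 is baseline on x86_64).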
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn sse_sum_f32_inner(ptr: *const f32, len: usize) -> f32 {
    use std::arch::x86_64::*;

    const VECTOR_SIZE: usize = 4;
    const UNROLL_FACTOR: usize = 4;
    const CHUNK_SIZE: usize = VECTOR_SIZE * UNROLL_FACTOR;

    let mut i = 0;

    let mut acc1 = _mm_setzero_ps();
    let mut acc2 = _mm_setzero_ps();
    let mut acc3 = _mm_setzero_ps();
    let mut acc4 = _mm_setzero_ps();

    while i + CHUNK_SIZE <= len {
        let v1 = _mm_loadu_ps(ptr.add(i));
        let v2 = _mm_loadu_ps(ptr.add(i + 4));
        let v3 = _mm_loadu_ps(ptr.add(i + 8));
        let v4 = _mm_loadu_ps(ptr.add(i + 12));

        acc1 = _mm_add_ps(acc1, v1);
        acc2 = _mm_add_ps(acc2, v2);
        acc3 = _mm_add_ps(acc3, v3);
        acc4 = _mm_add_ps(acc4, v4);

        i += CHUNK_SIZE;
    }

    while i + VECTOR_SIZE <= len {
        let v = _mm_loadu_ps(ptr.add(i));
        acc1 = _mm_add_ps(acc1, v);
        i += VECTOR_SIZE;
    }

    let combined1 = _mm_add_ps(acc1, acc2);
    let combined2 = _mm_add_ps(acc3, acc4);
    let final_acc = _mm_add_ps(combined1, combined2);

    let shuf = _mm_shuffle_ps(final_acc, final_acc, 0b1110);
    let sum_partial = _mm_add_ps(final_acc, shuf);
    let shuf2 = _mm_shuffle_ps(sum_partial, sum_partial, 0b0001);
    let final_result = _mm_add_ps(sum_partial, shuf2);

    let mut result = _mm_cvtss_f32(final_result);

    while i < len {
        result += *ptr.add(i);
        i += 1;
    }

    result
}

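/// NEON summation kernel using four accumulators, with 4-lane and scalar tails.
///
/// # Safety
///
/// `ptr` must be valid for `len` `f32` elements.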
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn neon_sum_f32_inner(ptr: *const f32, len: usize) -> f32 {
    use std::arch::aarch64::*;

    const VECTOR_SIZE: usize = 4;
    const UNROLL_FACTOR: usize = 4;
    const CHUNK_SIZE: usize = VECTOR_SIZE * UNROLL_FACTOR;

    let mut i = 0;

    let mut acc1 = vdupq_n_f32(0.0);
    let mut acc2 = vdupq_n_f32(0.0);
    let mut acc3 = vdupq_n_f32(0.0);
    let mut acc4 = vdupq_n_f32(0.0);

    while i + CHUNK_SIZE <= len {
        let v1 = vld1q_f32(ptr.add(i));
        let v2 = vld1q_f32(ptr.add(i + 4));
        let v3 = vld1q_f32(ptr.add(i + 8));
        let v4 = vld1q_f32(ptr.add(i + 12));

        acc1 = vaddq_f32(acc1, v1);
        acc2 = vaddq_f32(acc2, v2);
        acc3 = vaddq_f32(acc3, v3);
        acc4 = vaddq_f32(acc4, v4);

        i += CHUNK_SIZE;
    }

    while i + VECTOR_SIZE <= len {
        let v = vld1q_f32(ptr.add(i));
        acc1 = vaddq_f32(acc1, v);
        i += VECTOR_SIZE;
    }

    let combined1 = vaddq_f32(acc1, acc2);
    let combined2 = vaddq_f32(acc3, acc4);
    let final_acc = vaddq_f32(combined1, combined2);

    let mut result = vaddvq_f32(final_acc);

    while i < len {
        result += *ptr.add(i);
        i += 1;
    }

    result
}

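/// Scalar summation fallback; panics if the view is not contiguous.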
#[inline(always)]
fn scalar_sum_f32(input: &ArrayView1<f32>) -> f32 {
    let slice = input.as_slice().expect("array must be contiguous");
    slice.iter().sum()
}

#[cfg(test)]
mod sum_tests {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn test_sum_ultra_optimized() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let result = simd_sum_f32_ultra_optimized(&a.view());

        assert_eq!(result, 36.0);
    }

    #[test]
    fn test_sum_large_array() {
        let size = 10000;
        let a = Array1::from_elem(size, 2.5f32);

        let result = simd_sum_f32_ultra_optimized(&a.view());

        assert!((result - 25000.0).abs() < 0.001);
    }

    #[test]
    fn test_sum_empty() {
        let a = Array1::<f32>::from_vec(vec![]);

        let result = simd_sum_f32_ultra_optimized(&a.view());

        assert_eq!(result, 0.0);
    }
}