use crate::simd_ops::SimdUnifiedOps;
use ::ndarray::{Array, Array1, ArrayView1, ArrayView2, Dimension, Zip};
use num_traits::Float;
use std::ops::{Add, Div, Mul, Sub};

use crate::simd::detect::get_cpu_features;
use crate::simd::dot::{simd_div_f32, simd_mul_f32, simd_mul_f32_fast};
use crate::simd::traits::SimdOps;
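
/// Applies an element-wise binary operation to two arrays of the same shape and
/// returns a newly allocated result. Despite the SIMD-oriented trait bounds, the
/// operation itself is applied element by element via `Zip`.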
#[allow(dead_code)]
pub fn simd_binary_op<F, S1, S2, D>(
    a: &crate::ndarray::ArrayBase<S1, D>,
    b: &crate::ndarray::ArrayBase<S2, D>,
    op: fn(F, F) -> F,
) -> Array<F, D>
where
    F: SimdOps + Float + SimdUnifiedOps,
    S1: crate::ndarray::Data<Elem = F>,
    S2: crate::ndarray::Data<Elem = F>,
    D: Dimension,
{
    let mut result = Array::zeros(a.raw_dim());
    Zip::from(&mut result)
        .and(a)
        .and(b)
        .for_each(|r, &a, &b| *r = op(a, b));
    result
}
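
/// Element-wise `f32` multiplication tuned for x86_64: when AVX2 is available and both
/// inputs and the output are 32-byte aligned, it uses aligned loads/stores with software
/// prefetching and an 8-way unrolled loop; otherwise it falls back to unaligned AVX2
/// vectors, a scalar tail, or plain element-wise multiplication.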
#[allow(dead_code)]
pub fn simd_mul_f32_ultra(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");

    #[cfg(target_arch = "x86_64")]
    {
        let len = a.len();
        let mut result = vec![0.0f32; len];

        let a_ptr = a.as_slice().expect("array `a` must be contiguous").as_ptr();
        let b_ptr = b.as_slice().expect("array `b` must be contiguous").as_ptr();
        let result_ptr = result.as_mut_ptr();

        let features = get_cpu_features();

        use std::arch::x86_64::*;

        if features.has_avx2 {
            unsafe {
                let mut i = 0;
                const PREFETCH_DISTANCE: usize = 256;

                // Aligned fast path: 8 AVX2 vectors (64 floats) per iteration.
                let a_aligned = (a_ptr as usize) % 32 == 0;
                let b_aligned = (b_ptr as usize) % 32 == 0;
                let result_aligned = (result_ptr as usize) % 32 == 0;

                if a_aligned && b_aligned && result_aligned && len >= 64 {
                    while i + 64 <= len {
                        // Prefetch data that will be needed a few iterations ahead.
                        if i + PREFETCH_DISTANCE < len {
                            _mm_prefetch(
                                a_ptr.add(i + PREFETCH_DISTANCE) as *const i8,
                                _MM_HINT_T0,
                            );
                            _mm_prefetch(
                                b_ptr.add(i + PREFETCH_DISTANCE) as *const i8,
                                _MM_HINT_T0,
                            );
                        }

                        let a_vec1 = _mm256_load_ps(a_ptr.add(i));
                        let b_vec1 = _mm256_load_ps(b_ptr.add(i));
                        let result_vec1 = _mm256_mul_ps(a_vec1, b_vec1);

                        let a_vec2 = _mm256_load_ps(a_ptr.add(i + 8));
                        let b_vec2 = _mm256_load_ps(b_ptr.add(i + 8));
                        let result_vec2 = _mm256_mul_ps(a_vec2, b_vec2);

                        let a_vec3 = _mm256_load_ps(a_ptr.add(i + 16));
                        let b_vec3 = _mm256_load_ps(b_ptr.add(i + 16));
                        let result_vec3 = _mm256_mul_ps(a_vec3, b_vec3);

                        let a_vec4 = _mm256_load_ps(a_ptr.add(i + 24));
                        let b_vec4 = _mm256_load_ps(b_ptr.add(i + 24));
                        let result_vec4 = _mm256_mul_ps(a_vec4, b_vec4);

                        let a_vec5 = _mm256_load_ps(a_ptr.add(i + 32));
                        let b_vec5 = _mm256_load_ps(b_ptr.add(i + 32));
                        let result_vec5 = _mm256_mul_ps(a_vec5, b_vec5);

                        let a_vec6 = _mm256_load_ps(a_ptr.add(i + 40));
                        let b_vec6 = _mm256_load_ps(b_ptr.add(i + 40));
                        let result_vec6 = _mm256_mul_ps(a_vec6, b_vec6);

                        let a_vec7 = _mm256_load_ps(a_ptr.add(i + 48));
                        let b_vec7 = _mm256_load_ps(b_ptr.add(i + 48));
                        let result_vec7 = _mm256_mul_ps(a_vec7, b_vec7);

                        let a_vec8 = _mm256_load_ps(a_ptr.add(i + 56));
                        let b_vec8 = _mm256_load_ps(b_ptr.add(i + 56));
                        let result_vec8 = _mm256_mul_ps(a_vec8, b_vec8);

                        _mm256_store_ps(result_ptr.add(i), result_vec1);
                        _mm256_store_ps(result_ptr.add(i + 8), result_vec2);
                        _mm256_store_ps(result_ptr.add(i + 16), result_vec3);
                        _mm256_store_ps(result_ptr.add(i + 24), result_vec4);
                        _mm256_store_ps(result_ptr.add(i + 32), result_vec5);
                        _mm256_store_ps(result_ptr.add(i + 40), result_vec6);
                        _mm256_store_ps(result_ptr.add(i + 48), result_vec7);
                        _mm256_store_ps(result_ptr.add(i + 56), result_vec8);

                        i += 64;
                    }
                }

                // Unaligned path and leftovers from the unrolled loop: one AVX2 vector at a time.
                while i + 8 <= len {
                    let a_vec = _mm256_loadu_ps(a_ptr.add(i));
                    let b_vec = _mm256_loadu_ps(b_ptr.add(i));
                    let result_vec = _mm256_mul_ps(a_vec, b_vec);
                    _mm256_storeu_ps(result_ptr.add(i), result_vec);
                    i += 8;
                }

                // Scalar tail.
                while i < len {
                    *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
                    i += 1;
                }
            }
        } else {
            // No AVX2: plain scalar multiplication.
            for i in 0..len {
                result[i] = a[i] * b[i];
            }
        }

        return Array1::from_vec(result);
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        a * b
    }
}
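
/// Fused multiply-add over three `f32` slices: returns `a[i] * b[i] + c[i]` for each
/// element, using FMA or AVX2 on x86_64 and NEON on aarch64 when available, with a
/// scalar fallback otherwise.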
#[allow(dead_code)]
pub fn simd_fused_multiply_add_f32(
    a: &ArrayView1<f32>,
    b: &ArrayView1<f32>,
    c: &ArrayView1<f32>,
) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays a and b must have the same length");
    assert_eq!(a.len(), c.len(), "Arrays a and c must have the same length");

    let len = a.len();
    let mut result = Vec::with_capacity(len);

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        let features = get_cpu_features();

        if features.has_fma {
            unsafe {
                let mut i = 0;
                while i + 8 <= len {
                    let a_slice = &a.as_slice().expect("array `a` must be contiguous")[i..i + 8];
                    let b_slice = &b.as_slice().expect("array `b` must be contiguous")[i..i + 8];
                    let c_slice = &c.as_slice().expect("array `c` must be contiguous")[i..i + 8];

                    let a_vec = _mm256_loadu_ps(a_slice.as_ptr());
                    let b_vec = _mm256_loadu_ps(b_slice.as_ptr());
                    let c_vec = _mm256_loadu_ps(c_slice.as_ptr());
                    // Single-rounding fused multiply-add: a * b + c.
                    let result_vec = _mm256_fmadd_ps(a_vec, b_vec, c_vec);

                    let mut temp = [0.0f32; 8];
                    _mm256_storeu_ps(temp.as_mut_ptr(), result_vec);
                    result.extend_from_slice(&temp);
                    i += 8;
                }

                // Scalar tail keeps the fused rounding behaviour via `mul_add`.
                for j in i..len {
                    result.push(a[j].mul_add(b[j], c[j]));
                }
            }
        } else if is_x86_feature_detected!("avx2") {
            unsafe {
                let mut i = 0;
                while i + 8 <= len {
                    let a_slice = &a.as_slice().expect("array `a` must be contiguous")[i..i + 8];
                    let b_slice = &b.as_slice().expect("array `b` must be contiguous")[i..i + 8];
                    let c_slice = &c.as_slice().expect("array `c` must be contiguous")[i..i + 8];

                    let a_vec = _mm256_loadu_ps(a_slice.as_ptr());
                    let b_vec = _mm256_loadu_ps(b_slice.as_ptr());
                    let c_vec = _mm256_loadu_ps(c_slice.as_ptr());
                    // No FMA available: separate multiply and add (two roundings).
                    let mul_result = _mm256_mul_ps(a_vec, b_vec);
                    let result_vec = _mm256_add_ps(mul_result, c_vec);

                    let mut temp = [0.0f32; 8];
                    _mm256_storeu_ps(temp.as_mut_ptr(), result_vec);
                    result.extend_from_slice(&temp);
                    i += 8;
                }

                for j in i..len {
                    result.push(a[j] * b[j] + c[j]);
                }
            }
        } else {
            for i in 0..len {
                result.push(a[i] * b[i] + c[i]);
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        use std::arch::aarch64::*;

        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                let mut i = 0;
                while i + 4 <= len {
                    let a_slice = &a.as_slice().expect("array `a` must be contiguous")[i..i + 4];
                    let b_slice = &b.as_slice().expect("array `b` must be contiguous")[i..i + 4];
                    let c_slice = &c.as_slice().expect("array `c` must be contiguous")[i..i + 4];

                    let a_vec = vld1q_f32(a_slice.as_ptr());
                    let b_vec = vld1q_f32(b_slice.as_ptr());
                    let c_vec = vld1q_f32(c_slice.as_ptr());
                    // NEON fused multiply-accumulate: c + a * b.
                    let result_vec = vfmaq_f32(c_vec, a_vec, b_vec);

                    let mut temp = [0.0f32; 4];
                    vst1q_f32(temp.as_mut_ptr(), result_vec);
                    result.extend_from_slice(&temp);
                    i += 4;
                }

                for j in i..len {
                    result.push(a[j].mul_add(b[j], c[j]));
                }
            }
        } else {
            for i in 0..len {
                result.push(a[i] * b[i] + c[i]);
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        for i in 0..len {
            result.push(a[i] * b[i] + c[i]);
        }
    }

    Array1::from_vec(result)
}
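
/// Fused multiply-add over three `f64` slices: returns `a[i] * b[i] + c[i]` for each
/// element, using FMA/AVX2 on x86_64 and NEON on aarch64, with a scalar fallback.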
#[allow(dead_code)]
pub fn simd_fused_multiply_add_f64(
    a: &ArrayView1<f64>,
    b: &ArrayView1<f64>,
    c: &ArrayView1<f64>,
) -> Array1<f64> {
    assert_eq!(a.len(), b.len(), "Arrays a and b must have the same length");
    assert_eq!(a.len(), c.len(), "Arrays a and c must have the same length");

    let len = a.len();
    let mut result = Vec::with_capacity(len);

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        let features = get_cpu_features();

        if features.has_fma && features.has_avx2 {
            unsafe {
                let mut i = 0;
                while i + 4 <= len {
                    let a_slice = &a.as_slice().expect("array `a` must be contiguous")[i..i + 4];
                    let b_slice = &b.as_slice().expect("array `b` must be contiguous")[i..i + 4];
                    let c_slice = &c.as_slice().expect("array `c` must be contiguous")[i..i + 4];

                    let a_vec = _mm256_loadu_pd(a_slice.as_ptr());
                    let b_vec = _mm256_loadu_pd(b_slice.as_ptr());
                    let c_vec = _mm256_loadu_pd(c_slice.as_ptr());
                    // Single-rounding fused multiply-add: a * b + c.
                    let result_vec = _mm256_fmadd_pd(a_vec, b_vec, c_vec);

                    let mut temp = [0.0f64; 4];
                    _mm256_storeu_pd(temp.as_mut_ptr(), result_vec);
                    result.extend_from_slice(&temp);
                    i += 4;
                }

                for j in i..len {
                    result.push(a[j].mul_add(b[j], c[j]));
                }
            }
        } else if is_x86_feature_detected!("avx2") {
            unsafe {
                let mut i = 0;
                while i + 4 <= len {
                    let a_slice = &a.as_slice().expect("array `a` must be contiguous")[i..i + 4];
                    let b_slice = &b.as_slice().expect("array `b` must be contiguous")[i..i + 4];
                    let c_slice = &c.as_slice().expect("array `c` must be contiguous")[i..i + 4];

                    let a_vec = _mm256_loadu_pd(a_slice.as_ptr());
                    let b_vec = _mm256_loadu_pd(b_slice.as_ptr());
                    let c_vec = _mm256_loadu_pd(c_slice.as_ptr());
                    // No FMA available: separate multiply and add (two roundings).
                    let mul_result = _mm256_mul_pd(a_vec, b_vec);
                    let result_vec = _mm256_add_pd(mul_result, c_vec);

                    let mut temp = [0.0f64; 4];
                    _mm256_storeu_pd(temp.as_mut_ptr(), result_vec);
                    result.extend_from_slice(&temp);
                    i += 4;
                }

                for j in i..len {
                    result.push(a[j] * b[j] + c[j]);
                }
            }
        } else {
            for i in 0..len {
                result.push(a[i].mul_add(b[i], c[i]));
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        use std::arch::aarch64::*;

        // NEON is a baseline feature on aarch64, so no runtime detection is needed.
        unsafe {
            let mut i = 0;
            while i + 2 <= len {
                let a_slice = &a.as_slice().expect("array `a` must be contiguous")[i..i + 2];
                let b_slice = &b.as_slice().expect("array `b` must be contiguous")[i..i + 2];
                let c_slice = &c.as_slice().expect("array `c` must be contiguous")[i..i + 2];

                let a_vec = vld1q_f64(a_slice.as_ptr());
                let b_vec = vld1q_f64(b_slice.as_ptr());
                let c_vec = vld1q_f64(c_slice.as_ptr());
                // NEON fused multiply-accumulate: c + a * b.
                let result_vec = vfmaq_f64(c_vec, a_vec, b_vec);

                let mut temp = [0.0f64; 2];
                vst1q_f64(temp.as_mut_ptr(), result_vec);
                result.extend_from_slice(&temp);
                i += 2;
            }

            for j in i..len {
                result.push(a[j].mul_add(b[j], c[j]));
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        for i in 0..len {
            result.push(a[i].mul_add(b[i], c[i]));
        }
    }

    Array1::from(result)
}
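
/// GEMV-style helper that delegates to `f32::simd_gemv` and then scales the whole
/// result by `alpha`. Note: assuming `simd_gemv` computes `y = A·x + beta*y`, applying
/// `alpha` afterwards also scales the `beta*y` term, so this evaluates
/// `alpha*(A·x + beta*y)` rather than the BLAS convention `alpha*A·x + beta*y`.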
#[allow(dead_code)]
pub fn simd_gemv_cache_optimized_f32(
    alpha: f32,
    a: &ArrayView2<f32>,
    x: &ArrayView1<f32>,
    beta: f32,
    y: &mut Array1<f32>,
) {
    f32::simd_gemv(a, x, beta, y);

    if alpha != 1.0 {
        for elem in y.iter_mut() {
            *elem *= alpha;
        }
    }
}

/// Cache-optimized element-wise addition for `f32` (thin wrapper over
/// `f32::simd_add_cache_optimized`).
#[allow(dead_code)]
pub fn simd_add_cache_optimized_f32(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    f32::simd_add_cache_optimized(a, b)
}

/// Advanced FMA (`a * b + c`) for `f32` (thin wrapper over
/// `f32::simd_fma_advanced_optimized`).
#[allow(dead_code)]
pub fn simd_fma_advanced_optimized_f32(
    a: &ArrayView1<f32>,
    b: &ArrayView1<f32>,
    c: &ArrayView1<f32>,
) -> Array1<f32> {
    f32::simd_fma_advanced_optimized(a, b, c)
}

/// Size-adaptive element-wise addition for `f32` (thin wrapper over
/// `f32::simd_add_adaptive`).
#[allow(dead_code)]
pub fn simd_adaptive_add_f32(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    f32::simd_add_adaptive(a, b)
}

/// Cache-optimized element-wise addition for `f64` (thin wrapper over
/// `f64::simd_add_cache_optimized`).
#[allow(dead_code)]
pub fn simd_add_cache_optimized_f64(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
    f64::simd_add_cache_optimized(a, b)
}

/// Advanced FMA (`a * b + c`) for `f64` (thin wrapper over
/// `f64::simd_fma_advanced_optimized`).
#[allow(dead_code)]
pub fn simd_fma_advanced_optimized_f64(
    a: &ArrayView1<f64>,
    b: &ArrayView1<f64>,
    c: &ArrayView1<f64>,
) -> Array1<f64> {
    f64::simd_fma_advanced_optimized(a, b, c)
}

/// Size-adaptive element-wise addition for `f64` (thin wrapper over
/// `f64::simd_add_adaptive`).
#[allow(dead_code)]
pub fn simd_adaptive_add_f64(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
    f64::simd_add_adaptive(a, b)
}
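
/// Element-wise `f32` multiplication using unaligned AVX2 loads with deep (8-way) loop
/// unrolling on x86_64; falls back to `simd_mul_f32_fast` when AVX2 is unavailable or
/// on other architectures.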
#[allow(dead_code)]
pub fn simd_mul_f32_blazing(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        let features = get_cpu_features();

        if features.has_avx2 {
            let len = a.len();
            let mut result = vec![0.0f32; len];

            let a_ptr = a.as_slice().expect("array `a` must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array `b` must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();

            unsafe {
                let mut i = 0;

                // 8-way unrolled main loop: 64 floats per iteration.
                while i + 64 <= len {
                    let a1 = _mm256_loadu_ps(a_ptr.add(i));
                    let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
                    let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
                    let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));
                    let a5 = _mm256_loadu_ps(a_ptr.add(i + 32));
                    let a6 = _mm256_loadu_ps(a_ptr.add(i + 40));
                    let a7 = _mm256_loadu_ps(a_ptr.add(i + 48));
                    let a8 = _mm256_loadu_ps(a_ptr.add(i + 56));

                    let b1 = _mm256_loadu_ps(b_ptr.add(i));
                    let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
                    let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
                    let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));
                    let b5 = _mm256_loadu_ps(b_ptr.add(i + 32));
                    let b6 = _mm256_loadu_ps(b_ptr.add(i + 40));
                    let b7 = _mm256_loadu_ps(b_ptr.add(i + 48));
                    let b8 = _mm256_loadu_ps(b_ptr.add(i + 56));

                    let r1 = _mm256_mul_ps(a1, b1);
                    let r2 = _mm256_mul_ps(a2, b2);
                    let r3 = _mm256_mul_ps(a3, b3);
                    let r4 = _mm256_mul_ps(a4, b4);
                    let r5 = _mm256_mul_ps(a5, b5);
                    let r6 = _mm256_mul_ps(a6, b6);
                    let r7 = _mm256_mul_ps(a7, b7);
                    let r8 = _mm256_mul_ps(a8, b8);

                    _mm256_storeu_ps(result_ptr.add(i), r1);
                    _mm256_storeu_ps(result_ptr.add(i + 8), r2);
                    _mm256_storeu_ps(result_ptr.add(i + 16), r3);
                    _mm256_storeu_ps(result_ptr.add(i + 24), r4);
                    _mm256_storeu_ps(result_ptr.add(i + 32), r5);
                    _mm256_storeu_ps(result_ptr.add(i + 40), r6);
                    _mm256_storeu_ps(result_ptr.add(i + 48), r7);
                    _mm256_storeu_ps(result_ptr.add(i + 56), r8);

                    i += 64;
                }

                // 4-way unrolled loop: 32 floats per iteration.
                while i + 32 <= len {
                    let a1 = _mm256_loadu_ps(a_ptr.add(i));
                    let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
                    let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
                    let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));

                    let b1 = _mm256_loadu_ps(b_ptr.add(i));
                    let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
                    let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
                    let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));

                    let r1 = _mm256_mul_ps(a1, b1);
                    let r2 = _mm256_mul_ps(a2, b2);
                    let r3 = _mm256_mul_ps(a3, b3);
                    let r4 = _mm256_mul_ps(a4, b4);

                    _mm256_storeu_ps(result_ptr.add(i), r1);
                    _mm256_storeu_ps(result_ptr.add(i + 8), r2);
                    _mm256_storeu_ps(result_ptr.add(i + 16), r3);
                    _mm256_storeu_ps(result_ptr.add(i + 24), r4);

                    i += 32;
                }

                // Single-vector loop.
                while i + 8 <= len {
                    let a_vec = _mm256_loadu_ps(a_ptr.add(i));
                    let b_vec = _mm256_loadu_ps(b_ptr.add(i));
                    let result_vec = _mm256_mul_ps(a_vec, b_vec);
                    _mm256_storeu_ps(result_ptr.add(i), result_vec);
                    i += 8;
                }

                // Scalar tail.
                while i < len {
                    *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
                    i += 1;
                }

                return Array1::from(result);
            }
        }
    }

    simd_mul_f32_fast(a, b)
}
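
/// Element-wise `f32` multiplication that processes two AVX2 vectors (one 64-byte cache
/// line of each input, 16 floats) per iteration on x86_64; falls back to
/// `simd_mul_f32_fast` otherwise.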
#[allow(dead_code)]
pub fn simd_mul_f32_cache_optimized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        let features = get_cpu_features();

        if features.has_avx2 {
            let len = a.len();
            let mut result = vec![0.0f32; len];

            let a_ptr = a.as_slice().expect("array `a` must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array `b` must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();

            unsafe {
                let mut i = 0;

                // Two AVX2 vectors per iteration: 16 f32 values = one 64-byte cache line.
                const CACHE_LINE_ELEMENTS: usize = 16;

                while i + CACHE_LINE_ELEMENTS <= len {
                    let a1 = _mm256_loadu_ps(a_ptr.add(i));
                    let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));

                    let b1 = _mm256_loadu_ps(b_ptr.add(i));
                    let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));

                    let r1 = _mm256_mul_ps(a1, b1);
                    let r2 = _mm256_mul_ps(a2, b2);

                    _mm256_storeu_ps(result_ptr.add(i), r1);
                    _mm256_storeu_ps(result_ptr.add(i + 8), r2);

                    i += CACHE_LINE_ELEMENTS;
                }

                // Single-vector loop.
                while i + 8 <= len {
                    let a_vec = _mm256_loadu_ps(a_ptr.add(i));
                    let b_vec = _mm256_loadu_ps(b_ptr.add(i));
                    let result_vec = _mm256_mul_ps(a_vec, b_vec);
                    _mm256_storeu_ps(result_ptr.add(i), result_vec);
                    i += 8;
                }

                // Scalar tail.
                while i < len {
                    *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
                    i += 1;
                }

                return Array1::from(result);
            }
        }
    }

    simd_mul_f32_fast(a, b)
}
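
/// Minimal-overhead element-wise `f32` multiplication: a single non-unrolled AVX2 loop
/// with a scalar tail on x86_64, and a plain scalar loop everywhere else.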
#[allow(dead_code)]
pub fn simd_mul_f32_lightweight(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");

    let len = a.len();
    let mut result = vec![0.0f32; len];

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if std::arch::is_x86_feature_detected!("avx2") {
            let a_ptr = a.as_slice().expect("array `a` must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array `b` must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();

            unsafe {
                let mut i = 0;

                // One AVX2 vector (8 floats) per iteration.
                while i + 8 <= len {
                    let a_vec = _mm256_loadu_ps(a_ptr.add(i));
                    let b_vec = _mm256_loadu_ps(b_ptr.add(i));
                    let result_vec = _mm256_mul_ps(a_vec, b_vec);
                    _mm256_storeu_ps(result_ptr.add(i), result_vec);
                    i += 8;
                }

                // Scalar tail.
                while i < len {
                    *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
                    i += 1;
                }

                return Array1::from(result);
            }
        }
    }

    // Scalar fallback (non-x86_64 targets or x86_64 without AVX2).
    for i in 0..len {
        result[i] = a[i] * b[i];
    }

    Array1::from_vec(result)
}
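
/// Element-wise `f32` multiplication gated on AVX-512F detection. Note that the body
/// currently processes data with 256-bit AVX2 intrinsics rather than 512-bit
/// `_mm512_*` operations, so the AVX-512 check only selects this more aggressively
/// unrolled path; falls back to `simd_mul_f32_lightweight` otherwise.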
#[allow(dead_code)]
pub fn simd_mul_f32_avx512(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        let features = get_cpu_features();

        if features.has_avx512f {
            let len = a.len();
            let mut result = vec![0.0f32; len];

            let a_ptr = a.as_slice().expect("array `a` must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array `b` must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();

            unsafe {
                let mut i = 0;

                // 8-way unrolled loop over 256-bit vectors: 64 floats per iteration.
                while i + 64 <= len {
                    let a1 = _mm256_loadu_ps(a_ptr.add(i));
                    let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
                    let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
                    let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));
                    let a5 = _mm256_loadu_ps(a_ptr.add(i + 32));
                    let a6 = _mm256_loadu_ps(a_ptr.add(i + 40));
                    let a7 = _mm256_loadu_ps(a_ptr.add(i + 48));
                    let a8 = _mm256_loadu_ps(a_ptr.add(i + 56));

                    let b1 = _mm256_loadu_ps(b_ptr.add(i));
                    let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
                    let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
                    let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));
                    let b5 = _mm256_loadu_ps(b_ptr.add(i + 32));
                    let b6 = _mm256_loadu_ps(b_ptr.add(i + 40));
                    let b7 = _mm256_loadu_ps(b_ptr.add(i + 48));
                    let b8 = _mm256_loadu_ps(b_ptr.add(i + 56));

                    let r1 = _mm256_mul_ps(a1, b1);
                    let r2 = _mm256_mul_ps(a2, b2);
                    let r3 = _mm256_mul_ps(a3, b3);
                    let r4 = _mm256_mul_ps(a4, b4);
                    let r5 = _mm256_mul_ps(a5, b5);
                    let r6 = _mm256_mul_ps(a6, b6);
                    let r7 = _mm256_mul_ps(a7, b7);
                    let r8 = _mm256_mul_ps(a8, b8);

                    _mm256_storeu_ps(result_ptr.add(i), r1);
                    _mm256_storeu_ps(result_ptr.add(i + 8), r2);
                    _mm256_storeu_ps(result_ptr.add(i + 16), r3);
                    _mm256_storeu_ps(result_ptr.add(i + 24), r4);
                    _mm256_storeu_ps(result_ptr.add(i + 32), r5);
                    _mm256_storeu_ps(result_ptr.add(i + 40), r6);
                    _mm256_storeu_ps(result_ptr.add(i + 48), r7);
                    _mm256_storeu_ps(result_ptr.add(i + 56), r8);

                    i += 64;
                }

                // Two-vector loop: 16 floats per iteration.
                while i + 16 <= len {
                    let a1 = _mm256_loadu_ps(a_ptr.add(i));
                    let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
                    let b1 = _mm256_loadu_ps(b_ptr.add(i));
                    let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));

                    let r1 = _mm256_mul_ps(a1, b1);
                    let r2 = _mm256_mul_ps(a2, b2);

                    _mm256_storeu_ps(result_ptr.add(i), r1);
                    _mm256_storeu_ps(result_ptr.add(i + 8), r2);

                    i += 16;
                }

                // Scalar tail.
                while i < len {
                    *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
                    i += 1;
                }

                return Array1::from_vec(result);
            }
        }
    }

    simd_mul_f32_lightweight(a, b)
}
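
/// Element-wise `f32` multiplication that avoids a per-element branch in the tail by
/// copying the remainder into fixed-size stack buffers and processing it with one more
/// full-width AVX2 multiply; falls back to a scalar loop without AVX2.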
#[allow(dead_code)]
pub fn simd_mul_f32_branchfree(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");

    let len = a.len();
    let mut result = vec![0.0f32; len];

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if std::arch::is_x86_feature_detected!("avx2") {
            let a_ptr = a.as_slice().expect("array `a` must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array `b` must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();

            unsafe {
                let mut i = 0;

                // Main loop over the largest multiple of 8 elements.
                let vector_len = len & !7;
                while i < vector_len {
                    let a_vec = _mm256_loadu_ps(a_ptr.add(i));
                    let b_vec = _mm256_loadu_ps(b_ptr.add(i));
                    let result_vec = _mm256_mul_ps(a_vec, b_vec);
                    _mm256_storeu_ps(result_ptr.add(i), result_vec);
                    i += 8;
                }

                // Remainder: copy the tail into zero-padded stack buffers and run one
                // more full-width multiply instead of branching per element.
                if i < len {
                    let remaining = len - i;

                    let mut a_temp = [0.0f32; 8];
                    let mut b_temp = [0.0f32; 8];

                    for j in 0..remaining {
                        a_temp[j] = *a_ptr.add(i + j);
                        b_temp[j] = *b_ptr.add(i + j);
                    }

                    let a_vec = _mm256_loadu_ps(a_temp.as_ptr());
                    let b_vec = _mm256_loadu_ps(b_temp.as_ptr());
                    let result_vec = _mm256_mul_ps(a_vec, b_vec);

                    let mut result_temp = [0.0f32; 8];
                    _mm256_storeu_ps(result_temp.as_mut_ptr(), result_vec);

                    for j in 0..remaining {
                        *result_ptr.add(i + j) = result_temp[j];
                    }
                }

                return Array1::from_vec(result);
            }
        }
    }

    // Scalar fallback (non-x86_64 targets or x86_64 without AVX2).
    for i in 0..len {
        result[i] = a[i] * b[i];
    }

    Array1::from_vec(result)
}
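
/// Element-wise `f32` multiplication with interleaved loads and multiplies intended to
/// keep more memory requests in flight per iteration (64 elements at a time); falls
/// back to `simd_mul_f32_lightweight` without AVX2 or on other architectures.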
#[allow(dead_code)]
pub fn simd_mul_f32_bandwidth_saturated(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if std::arch::is_x86_feature_detected!("avx2") {
            let len = a.len();
            let mut result = vec![0.0f32; len];

            let a_ptr = a.as_slice().expect("array `a` must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("array `b` must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();

            unsafe {
                let mut i = 0;

                // Four cache lines (64 f32 elements) of each input per iteration, with
                // loads and multiplies interleaved to keep memory requests in flight.
                const CACHE_LINES_PER_ITERATION: usize = 4;
                const ELEMENTS_PER_ITERATION: usize = CACHE_LINES_PER_ITERATION * 16;

                while i + ELEMENTS_PER_ITERATION <= len {
                    let a1 = _mm256_loadu_ps(a_ptr.add(i));
                    let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
                    let b1 = _mm256_loadu_ps(b_ptr.add(i));
                    let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));

                    let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
                    let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));
                    let r1 = _mm256_mul_ps(a1, b1);
                    let r2 = _mm256_mul_ps(a2, b2);

                    let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
                    let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));

                    let a5 = _mm256_loadu_ps(a_ptr.add(i + 32));
                    let a6 = _mm256_loadu_ps(a_ptr.add(i + 40));
                    let r3 = _mm256_mul_ps(a3, b3);
                    let r4 = _mm256_mul_ps(a4, b4);

                    let b5 = _mm256_loadu_ps(b_ptr.add(i + 32));
                    let b6 = _mm256_loadu_ps(b_ptr.add(i + 40));

                    let a7 = _mm256_loadu_ps(a_ptr.add(i + 48));
                    let a8 = _mm256_loadu_ps(a_ptr.add(i + 56));
                    let r5 = _mm256_mul_ps(a5, b5);
                    let r6 = _mm256_mul_ps(a6, b6);

                    let b7 = _mm256_loadu_ps(b_ptr.add(i + 48));
                    let b8 = _mm256_loadu_ps(b_ptr.add(i + 56));

                    let r7 = _mm256_mul_ps(a7, b7);
                    let r8 = _mm256_mul_ps(a8, b8);

                    _mm256_storeu_ps(result_ptr.add(i), r1);
                    _mm256_storeu_ps(result_ptr.add(i + 8), r2);
                    _mm256_storeu_ps(result_ptr.add(i + 16), r3);
                    _mm256_storeu_ps(result_ptr.add(i + 24), r4);
                    _mm256_storeu_ps(result_ptr.add(i + 32), r5);
                    _mm256_storeu_ps(result_ptr.add(i + 40), r6);
                    _mm256_storeu_ps(result_ptr.add(i + 48), r7);
                    _mm256_storeu_ps(result_ptr.add(i + 56), r8);

                    i += ELEMENTS_PER_ITERATION;
                }

                // Single-vector loop.
                while i + 8 <= len {
                    let a_vec = _mm256_loadu_ps(a_ptr.add(i));
                    let b_vec = _mm256_loadu_ps(b_ptr.add(i));
                    let result_vec = _mm256_mul_ps(a_vec, b_vec);
                    _mm256_storeu_ps(result_ptr.add(i), result_vec);
                    i += 8;
                }

                // Scalar tail.
                while i < len {
                    *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
                    i += 1;
                }

                return Array1::from_vec(result);
            }
        }
    }

    simd_mul_f32_lightweight(a, b)
}
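
/// Size-dispatching wrapper that picks a multiplication kernel from the array length:
/// very large and very small inputs use `simd_mul_f32_lightweight`, mid-sized inputs
/// use `simd_mul_f32_branchfree`.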
#[allow(dead_code)]
pub fn simd_mul_f32_ultimate(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    let len = a.len();

    if len >= 10_000_000 {
        simd_mul_f32_lightweight(a, b)
    } else if len >= 100_000 {
        simd_mul_f32_branchfree(a, b)
    } else if len >= 10_000 {
        simd_mul_f32_branchfree(a, b)
    } else {
        simd_mul_f32_lightweight(a, b)
    }
}
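
/// Element-wise `f32` multiplication that walks the inputs one 64-byte cache line
/// (16 floats) at a time with prefetching, using non-temporal stores when the output
/// buffer is 32-byte aligned.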
pub fn simd_mul_f32_cacheline(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");

    let len = a.len();
    // Every element of `result` is written below before the array is returned.
    let mut result = unsafe { Array1::uninit(len).assume_init() };

    let a_ptr = a.as_slice().expect("array `a` must be contiguous").as_ptr();
    let b_ptr = b.as_slice().expect("array `b` must be contiguous").as_ptr();
    let result_ptr: *mut f32 = result.as_mut_ptr();

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if std::arch::is_x86_feature_detected!("avx2") {
            unsafe {
                // 16 f32 elements = one 64-byte cache line per iteration.
                let cache_line_size = 16;
                let vector_end = len - (len % cache_line_size);
                // `_mm256_stream_ps` requires a 32-byte aligned destination; fall back
                // to unaligned stores when the output buffer is not aligned.
                let result_aligned = (result_ptr as usize) % 32 == 0;
                let mut i = 0;

                while i < vector_end {
                    // Prefetch the next cache line of both inputs.
                    _mm_prefetch(a_ptr.add(i + cache_line_size) as *const i8, _MM_HINT_T0);
                    _mm_prefetch(b_ptr.add(i + cache_line_size) as *const i8, _MM_HINT_T0);

                    let a_vec1 = _mm256_loadu_ps(a_ptr.add(i));
                    let a_vec2 = _mm256_loadu_ps(a_ptr.add(i + 8));
                    let b_vec1 = _mm256_loadu_ps(b_ptr.add(i));
                    let b_vec2 = _mm256_loadu_ps(b_ptr.add(i + 8));

                    let result_vec1 = _mm256_mul_ps(a_vec1, b_vec1);
                    let result_vec2 = _mm256_mul_ps(a_vec2, b_vec2);

                    if result_aligned {
                        // Non-temporal stores bypass the cache for the output buffer.
                        _mm256_stream_ps(result_ptr.add(i), result_vec1);
                        _mm256_stream_ps(result_ptr.add(i + 8), result_vec2);
                    } else {
                        _mm256_storeu_ps(result_ptr.add(i), result_vec1);
                        _mm256_storeu_ps(result_ptr.add(i + 8), result_vec2);
                    }

                    i += cache_line_size;
                }

                // Make any non-temporal stores visible before returning.
                _mm_sfence();

                // Scalar tail.
                while i < len {
                    *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
                    i += 1;
                }

                return result;
            }
        }
    }

    // Scalar fallback (non-x86_64 targets or x86_64 without AVX2).
    for i in 0..len {
        unsafe {
            *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
        }
    }

    result
}
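
/// Element-wise `f32` multiplication structured to keep several independent AVX2
/// multiplies in flight (software pipelining): 32-, 16-, and 8-element blocks, then a
/// scalar tail.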
pub fn simd_mul_f32_pipelined(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");

    let len = a.len();
    // Every element of `result` is written below before the array is returned.
    let mut result = unsafe { Array1::uninit(len).assume_init() };

    let a_ptr = a.as_slice().expect("array `a` must be contiguous").as_ptr();
    let b_ptr = b.as_slice().expect("array `b` must be contiguous").as_ptr();
    let result_ptr: *mut f32 = result.as_mut_ptr();

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if std::arch::is_x86_feature_detected!("avx2") {
            unsafe {
                let mut i = 0;
                // Process 4 independent AVX2 vectors (32 floats) per iteration so loads
                // and multiplies can overlap in the pipeline.
                let block_size = 32;
                let block_end = len - (len % block_size);

                while i < block_end {
                    let a1 = _mm256_loadu_ps(a_ptr.add(i));
                    let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
                    let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
                    let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));

                    let b1 = _mm256_loadu_ps(b_ptr.add(i));
                    let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
                    let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
                    let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));

                    let r1 = _mm256_mul_ps(a1, b1);
                    let r2 = _mm256_mul_ps(a2, b2);
                    let r3 = _mm256_mul_ps(a3, b3);
                    let r4 = _mm256_mul_ps(a4, b4);

                    _mm256_storeu_ps(result_ptr.add(i), r1);
                    _mm256_storeu_ps(result_ptr.add(i + 8), r2);
                    _mm256_storeu_ps(result_ptr.add(i + 16), r3);
                    _mm256_storeu_ps(result_ptr.add(i + 24), r4);

                    i += block_size;
                }

                // 16-element blocks.
                while i + 16 <= len {
                    let a1 = _mm256_loadu_ps(a_ptr.add(i));
                    let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
                    let b1 = _mm256_loadu_ps(b_ptr.add(i));
                    let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));

                    let r1 = _mm256_mul_ps(a1, b1);
                    let r2 = _mm256_mul_ps(a2, b2);

                    _mm256_storeu_ps(result_ptr.add(i), r1);
                    _mm256_storeu_ps(result_ptr.add(i + 8), r2);

                    i += 16;
                }

                // Single-vector loop.
                while i + 8 <= len {
                    let a_vec = _mm256_loadu_ps(a_ptr.add(i));
                    let b_vec = _mm256_loadu_ps(b_ptr.add(i));
                    let result_vec = _mm256_mul_ps(a_vec, b_vec);
                    _mm256_storeu_ps(result_ptr.add(i), result_vec);
                    i += 8;
                }

                // Scalar tail.
                while i < len {
                    *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
                    i += 1;
                }

                return result;
            }
        }
    }

    // Scalar fallback (non-x86_64 targets or x86_64 without AVX2).
    for i in 0..len {
        unsafe {
            *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
        }
    }

    result
}
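
/// Element-wise `f32` multiplication that processes the inputs in large chunks
/// (roughly 512 KiB of floats per input) with up-front prefetching of each chunk,
/// aiming to reduce TLB and cache misses on very large arrays.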
pub fn simd_mul_f32_tlb_optimized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");

    let len = a.len();
    // Every element of `result` is written below before the array is returned.
    let mut result = unsafe { Array1::uninit(len).assume_init() };

    let a_ptr = a.as_slice().expect("array `a` must be contiguous").as_ptr();
    let b_ptr = b.as_slice().expect("array `b` must be contiguous").as_ptr();
    let result_ptr: *mut f32 = result.as_mut_ptr();

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if std::arch::is_x86_feature_detected!("avx2") {
            unsafe {
                // Work through the inputs in ~512 KiB chunks so each chunk's pages stay
                // warm in the TLB while it is being processed.
                const CHUNK_SIZE: usize = 512 * 1024 / 4;
                let mut pos = 0;

                while pos < len {
                    let chunk_end = std::cmp::min(pos + CHUNK_SIZE, len);

                    // Touch the chunk up front to prime the TLB and caches.
                    let prefetch_distance = 64;
                    for j in (pos..chunk_end).step_by(prefetch_distance) {
                        _mm_prefetch(a_ptr.add(j) as *const i8, _MM_HINT_T0);
                        _mm_prefetch(b_ptr.add(j) as *const i8, _MM_HINT_T0);
                    }

                    let mut i = pos;

                    // 8-way unrolled loop within the chunk.
                    while i + 64 <= chunk_end {
                        let a1 = _mm256_loadu_ps(a_ptr.add(i));
                        let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
                        let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
                        let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));
                        let a5 = _mm256_loadu_ps(a_ptr.add(i + 32));
                        let a6 = _mm256_loadu_ps(a_ptr.add(i + 40));
                        let a7 = _mm256_loadu_ps(a_ptr.add(i + 48));
                        let a8 = _mm256_loadu_ps(a_ptr.add(i + 56));

                        let b1 = _mm256_loadu_ps(b_ptr.add(i));
                        let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
                        let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
                        let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));
                        let b5 = _mm256_loadu_ps(b_ptr.add(i + 32));
                        let b6 = _mm256_loadu_ps(b_ptr.add(i + 40));
                        let b7 = _mm256_loadu_ps(b_ptr.add(i + 48));
                        let b8 = _mm256_loadu_ps(b_ptr.add(i + 56));

                        let r1 = _mm256_mul_ps(a1, b1);
                        let r2 = _mm256_mul_ps(a2, b2);
                        let r3 = _mm256_mul_ps(a3, b3);
                        let r4 = _mm256_mul_ps(a4, b4);
                        let r5 = _mm256_mul_ps(a5, b5);
                        let r6 = _mm256_mul_ps(a6, b6);
                        let r7 = _mm256_mul_ps(a7, b7);
                        let r8 = _mm256_mul_ps(a8, b8);

                        _mm256_storeu_ps(result_ptr.add(i), r1);
                        _mm256_storeu_ps(result_ptr.add(i + 8), r2);
                        _mm256_storeu_ps(result_ptr.add(i + 16), r3);
                        _mm256_storeu_ps(result_ptr.add(i + 24), r4);
                        _mm256_storeu_ps(result_ptr.add(i + 32), r5);
                        _mm256_storeu_ps(result_ptr.add(i + 40), r6);
                        _mm256_storeu_ps(result_ptr.add(i + 48), r7);
                        _mm256_storeu_ps(result_ptr.add(i + 56), r8);

                        i += 64;
                    }

                    // Single-vector loop within the chunk.
                    while i + 8 <= chunk_end {
                        let a_vec = _mm256_loadu_ps(a_ptr.add(i));
                        let b_vec = _mm256_loadu_ps(b_ptr.add(i));
                        let result_vec = _mm256_mul_ps(a_vec, b_vec);
                        _mm256_storeu_ps(result_ptr.add(i), result_vec);
                        i += 8;
                    }

                    // Scalar tail of the chunk.
                    while i < chunk_end {
                        *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
                        i += 1;
                    }

                    pos = chunk_end;
                }

                return result;
            }
        }
    }

    // Scalar fallback (non-x86_64 targets or x86_64 without AVX2).
    for i in 0..len {
        unsafe {
            *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
        }
    }

    result
}
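
/// Size-dispatching wrapper: large arrays use `simd_mul_f32_blazing`, mid-sized arrays
/// `simd_mul_f32_fast`, and small arrays the baseline `simd_mul_f32`.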
#[allow(dead_code)]
pub fn simd_mul_f32_adaptive(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    let len = a.len();

    if len >= 100_000 {
        simd_mul_f32_blazing(a, b)
    } else if len >= 10_000 {
        simd_mul_f32_fast(a, b)
    } else {
        simd_mul_f32(a, b)
    }
}

/// Convenience alias for size-adaptive `f32` addition (`simd_adaptive_add_f32`).
#[allow(dead_code)]
pub fn simd_add_auto(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    simd_adaptive_add_f32(a, b)
}

/// Convenience alias for the most aggressive multiplication kernel (`simd_mul_f32_ultra`).
#[allow(dead_code)]
pub fn simd_mul_f32_hyperoptimized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    simd_mul_f32_ultra(a, b)
}

/// Size-adaptive element-wise `f32` addition (thin wrapper over `f32::simd_add_adaptive`).
#[allow(dead_code)]
pub fn simd_add_f32_adaptive(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    f32::simd_add_adaptive(a, b)
}
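
// A minimal sanity-check sketch: these tests assume only `ndarray` and the functions
// defined in this module, and compare each multiplication kernel against a plain
// element-wise reference on lengths that exercise both the unrolled loops and the
// scalar/remainder tails.
#[cfg(test)]
mod tests {
    use super::*;

    fn reference_mul(a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x * y).collect()
    }

    #[test]
    fn mul_kernels_match_reference() {
        // 203 is deliberately not a multiple of 8/16/64, so every tail path runs.
        let n = 203;
        let a: Array1<f32> = Array1::from_vec((0..n).map(|i| i as f32 * 0.5).collect());
        let b: Array1<f32> = Array1::from_vec((0..n).map(|i| 1.0 - i as f32 * 0.25).collect());
        let expected = reference_mul(a.as_slice().unwrap(), b.as_slice().unwrap());

        for result in [
            simd_mul_f32_ultra(&a.view(), &b.view()),
            simd_mul_f32_blazing(&a.view(), &b.view()),
            simd_mul_f32_lightweight(&a.view(), &b.view()),
            simd_mul_f32_branchfree(&a.view(), &b.view()),
            simd_mul_f32_ultimate(&a.view(), &b.view()),
        ] {
            assert_eq!(result.as_slice().unwrap(), expected.as_slice());
        }
    }

    #[test]
    fn fma_matches_reference_within_tolerance() {
        let n = 37;
        let a: Array1<f32> = Array1::from_vec((0..n).map(|i| i as f32).collect());
        let b: Array1<f32> = Array1::from_vec((0..n).map(|i| 0.1 * i as f32).collect());
        let c: Array1<f32> = Array1::from_vec(vec![2.0; n]);

        let result = simd_fused_multiply_add_f32(&a.view(), &b.view(), &c.view());
        for i in 0..n {
            // The fused and non-fused paths may differ by a rounding step.
            assert!((result[i] - (a[i] * b[i] + c[i])).abs() <= 1e-4 * (1.0 + result[i].abs()));
        }
    }
}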