use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use num_traits::Zero;
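
/// Unified element-wise and linear-algebra operations, implemented for `f32`
/// and `f64`. With the `simd` feature enabled, most methods dispatch to the
/// crate's SIMD kernels; otherwise they fall back to scalar/ndarray code.
///
/// A minimal sketch of precision-generic usage (illustrative only, hence
/// `ignore`):
///
/// ```ignore
/// fn cosine_similarity<T: SimdUnifiedOps + num_traits::Float>(
///     a: &ArrayView1<T>,
///     b: &ArrayView1<T>,
/// ) -> T {
///     T::simd_dot(a, b) / (T::simd_norm(a) * T::simd_norm(b))
/// }
/// ```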
pub trait SimdUnifiedOps: Sized + Copy + PartialOrd + Zero {
    fn simd_add(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_sub(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_mul(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_div(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_dot(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Self;
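
    /// Matrix-vector product accumulated into `y`: computes
    /// `y = A * x + beta * y` (BLAS GEMV without the `alpha` scale).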
    fn simd_gemv(a: &ArrayView2<Self>, x: &ArrayView1<Self>, beta: Self, y: &mut Array1<Self>);
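
    /// General matrix multiply: `C = alpha * A * B + beta * C`.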
    fn simd_gemm(
        alpha: Self,
        a: &ArrayView2<Self>,
        b: &ArrayView2<Self>,
        beta: Self,
        c: &mut Array2<Self>,
    );

    fn simd_norm(a: &ArrayView1<Self>) -> Self;

    fn simd_max(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_min(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_scalar_mul(a: &ArrayView1<Self>, scalar: Self) -> Array1<Self>;

    fn simd_sum(a: &ArrayView1<Self>) -> Self;

    fn simd_mean(a: &ArrayView1<Self>) -> Self;

    fn simd_max_element(a: &ArrayView1<Self>) -> Self;

    fn simd_min_element(a: &ArrayView1<Self>) -> Self;

    fn simd_fma(a: &ArrayView1<Self>, b: &ArrayView1<Self>, c: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_add_cache_optimized(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_fma_advanced_optimized(
        a: &ArrayView1<Self>,
        b: &ArrayView1<Self>,
        c: &ArrayView1<Self>,
    ) -> Array1<Self>;

    fn simd_add_adaptive(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_transpose(a: &ArrayView2<Self>) -> Array2<Self>;

    fn simd_abs(a: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_sqrt(a: &ArrayView1<Self>) -> Array1<Self>;

    fn simd_sum_squares(a: &ArrayView1<Self>) -> Self;

    /// Alias for [`simd_mul`](Self::simd_mul).
    fn simd_multiply(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self>;

    /// Returns `true` when the crate was built with the `simd` feature.
    fn simd_available() -> bool;
}
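
// Illustrative helper (a hypothetical addition, not part of the original API):
// because the trait is implemented for both `f32` and `f64`, precision-generic
// kernels can be written once against `SimdUnifiedOps`.
#[allow(dead_code)]
fn rms<T: SimdUnifiedOps + num_traits::Float>(v: &ArrayView1<T>) -> T {
    if v.is_empty() {
        T::zero()
    } else {
        // sqrt(sum(x^2) / n)
        let n = T::from(v.len()).expect("length fits in the float type");
        (T::simd_sum_squares(v) / n).sqrt()
    }
}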

impl SimdUnifiedOps for f32 {
    #[cfg(feature = "simd")]
    fn simd_add(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_add_f32(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_add(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a + b
    }

    #[cfg(feature = "simd")]
    fn simd_sub(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_sub_f32(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_sub(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a - b
    }

    #[cfg(feature = "simd")]
    fn simd_mul(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_mul_f32(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_mul(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a * b
    }

    #[cfg(feature = "simd")]
    fn simd_div(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_div_f32(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_div(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a / b
    }

    #[cfg(feature = "simd")]
    fn simd_dot(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Self {
        crate::simd::simd_dot_f32(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_dot(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Self {
        a.dot(b)
    }

    fn simd_gemv(a: &ArrayView2<Self>, x: &ArrayView1<Self>, beta: Self, y: &mut Array1<Self>) {
        let m = a.nrows();
        let n = a.ncols();

        assert_eq!(n, x.len());
        assert_eq!(m, y.len());

        // BLAS convention: beta == 0 overwrites y instead of scaling it.
        if beta == 0.0 {
            y.fill(0.0);
        } else if beta != 1.0 {
            y.mapv_inplace(|v| v * beta);
        }

        // One (possibly SIMD-accelerated) dot product per row.
        for i in 0..m {
            let row = a.row(i);
            y[i] += Self::simd_dot(&row, x);
        }
    }

    fn simd_gemm(
        alpha: Self,
        a: &ArrayView2<Self>,
        b: &ArrayView2<Self>,
        beta: Self,
        c: &mut Array2<Self>,
    ) {
        let m = a.nrows();
        let k = a.ncols();
        let n = b.ncols();

        assert_eq!(k, b.nrows());
        assert_eq!((m, n), c.dim());

        if beta == 0.0 {
            c.fill(0.0);
        } else if beta != 1.0 {
            c.mapv_inplace(|v| v * beta);
        }

        // Inner-product formulation: one dot product per output element.
        for i in 0..m {
            let a_row = a.row(i);
            for j in 0..n {
                let b_col = b.column(j);
                c[[i, j]] += alpha * Self::simd_dot(&a_row, &b_col);
            }
        }
    }

    fn simd_norm(a: &ArrayView1<Self>) -> Self {
        Self::simd_dot(a, a).sqrt()
    }

    #[cfg(feature = "simd")]
    fn simd_max(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_maximum_f32(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_max(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        // Scalar fallback: element-wise maximum.
        let mut result = Array1::zeros(a.len());
        for i in 0..a.len() {
            result[i] = a[i].max(b[i]);
        }
        result
    }

    #[cfg(feature = "simd")]
    fn simd_min(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_minimum_f32(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_min(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        // Scalar fallback: element-wise minimum.
        let mut result = Array1::zeros(a.len());
        for i in 0..a.len() {
            result[i] = a[i].min(b[i]);
        }
        result
    }

    #[cfg(feature = "simd")]
    fn simd_scalar_mul(a: &ArrayView1<Self>, scalar: Self) -> Array1<Self> {
        crate::simd::simd_scalar_mul_f32(a, scalar)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_scalar_mul(a: &ArrayView1<Self>, scalar: Self) -> Array1<Self> {
        a.mapv(|x| x * scalar)
    }

    #[cfg(feature = "simd")]
    fn simd_sum(a: &ArrayView1<Self>) -> Self {
        crate::simd::simd_sum_f32(a)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_sum(a: &ArrayView1<Self>) -> Self {
        a.sum()
    }

    fn simd_mean(a: &ArrayView1<Self>) -> Self {
        if a.is_empty() {
            0.0
        } else {
            Self::simd_sum(a) / (a.len() as f32)
        }
    }

    fn simd_max_element(a: &ArrayView1<Self>) -> Self {
        a.fold(f32::NEG_INFINITY, |acc, &x| acc.max(x))
    }

    fn simd_min_element(a: &ArrayView1<Self>) -> Self {
        a.fold(f32::INFINITY, |acc, &x| acc.min(x))
    }

    #[cfg(feature = "simd")]
    fn simd_fma(a: &ArrayView1<Self>, b: &ArrayView1<Self>, c: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_fused_multiply_add_f32(a, b, c)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_fma(a: &ArrayView1<Self>, b: &ArrayView1<Self>, c: &ArrayView1<Self>) -> Array1<Self> {
        // Scalar fallback: element-wise a * b + c.
        let mut result = Array1::zeros(a.len());
        for i in 0..a.len() {
            result[i] = a[i] * b[i] + c[i];
        }
        result
    }

    #[cfg(feature = "simd")]
    fn simd_add_cache_optimized(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_add_cache_optimized_f32(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_add_cache_optimized(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a + b
    }

    #[cfg(feature = "simd")]
    fn simd_fma_advanced_optimized(
        a: &ArrayView1<Self>,
        b: &ArrayView1<Self>,
        c: &ArrayView1<Self>,
    ) -> Array1<Self> {
        crate::simd::simd_fma_advanced_optimized_f32(a, b, c)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_fma_advanced_optimized(
        a: &ArrayView1<Self>,
        b: &ArrayView1<Self>,
        c: &ArrayView1<Self>,
    ) -> Array1<Self> {
        // Scalar fallback: identical to `simd_fma`.
        let mut result = Array1::zeros(a.len());
        for i in 0..a.len() {
            result[i] = a[i] * b[i] + c[i];
        }
        result
    }

    #[cfg(feature = "simd")]
    fn simd_add_adaptive(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_adaptive_add_f32(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_add_adaptive(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a + b
    }

    fn simd_transpose(a: &ArrayView2<Self>) -> Array2<Self> {
        a.t().to_owned()
    }

    fn simd_abs(a: &ArrayView1<Self>) -> Array1<Self> {
        a.mapv(|x| x.abs())
    }

    fn simd_sqrt(a: &ArrayView1<Self>) -> Array1<Self> {
        a.mapv(|x| x.sqrt())
    }

    fn simd_sum_squares(a: &ArrayView1<Self>) -> Self {
        a.iter().map(|&x| x * x).sum()
    }

    fn simd_multiply(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        Self::simd_mul(a, b)
    }

    #[cfg(feature = "simd")]
    fn simd_available() -> bool {
        true
    }

    #[cfg(not(feature = "simd"))]
    fn simd_available() -> bool {
        false
    }
}

impl SimdUnifiedOps for f64 {
    #[cfg(feature = "simd")]
    fn simd_add(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_add_f64(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_add(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a + b
    }

    #[cfg(feature = "simd")]
    fn simd_sub(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_sub_f64(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_sub(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a - b
    }

    #[cfg(feature = "simd")]
    fn simd_mul(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_mul_f64(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_mul(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a * b
    }

    #[cfg(feature = "simd")]
    fn simd_div(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_div_f64(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_div(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a / b
    }

    #[cfg(feature = "simd")]
    fn simd_dot(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Self {
        crate::simd::simd_dot_f64(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_dot(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Self {
        a.dot(b)
    }

    fn simd_gemv(a: &ArrayView2<Self>, x: &ArrayView1<Self>, beta: Self, y: &mut Array1<Self>) {
        let m = a.nrows();
        let n = a.ncols();

        assert_eq!(n, x.len());
        assert_eq!(m, y.len());

        // BLAS convention: beta == 0 overwrites y instead of scaling it.
        if beta == 0.0 {
            y.fill(0.0);
        } else if beta != 1.0 {
            y.mapv_inplace(|v| v * beta);
        }

        // One (possibly SIMD-accelerated) dot product per row.
        for i in 0..m {
            let row = a.row(i);
            y[i] += Self::simd_dot(&row, x);
        }
    }

    fn simd_gemm(
        alpha: Self,
        a: &ArrayView2<Self>,
        b: &ArrayView2<Self>,
        beta: Self,
        c: &mut Array2<Self>,
    ) {
        let m = a.nrows();
        let k = a.ncols();
        let n = b.ncols();

        assert_eq!(k, b.nrows());
        assert_eq!((m, n), c.dim());

        if beta == 0.0 {
            c.fill(0.0);
        } else if beta != 1.0 {
            c.mapv_inplace(|v| v * beta);
        }

        // Inner-product formulation: one dot product per output element.
        for i in 0..m {
            let a_row = a.row(i);
            for j in 0..n {
                let b_col = b.column(j);
                c[[i, j]] += alpha * Self::simd_dot(&a_row, &b_col);
            }
        }
    }

    fn simd_norm(a: &ArrayView1<Self>) -> Self {
        Self::simd_dot(a, a).sqrt()
    }

    #[cfg(feature = "simd")]
    fn simd_max(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_maximum_f64(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_max(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        // Scalar fallback: element-wise maximum.
        let mut result = Array1::zeros(a.len());
        for i in 0..a.len() {
            result[i] = a[i].max(b[i]);
        }
        result
    }

    #[cfg(feature = "simd")]
    fn simd_min(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_minimum_f64(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_min(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        // Scalar fallback: element-wise minimum.
        let mut result = Array1::zeros(a.len());
        for i in 0..a.len() {
            result[i] = a[i].min(b[i]);
        }
        result
    }

    #[cfg(feature = "simd")]
    fn simd_scalar_mul(a: &ArrayView1<Self>, scalar: Self) -> Array1<Self> {
        crate::simd::simd_scalar_mul_f64(a, scalar)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_scalar_mul(a: &ArrayView1<Self>, scalar: Self) -> Array1<Self> {
        a.mapv(|x| x * scalar)
    }

    #[cfg(feature = "simd")]
    fn simd_sum(a: &ArrayView1<Self>) -> Self {
        crate::simd::simd_sum_f64(a)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_sum(a: &ArrayView1<Self>) -> Self {
        a.sum()
    }

    fn simd_mean(a: &ArrayView1<Self>) -> Self {
        if a.is_empty() {
            0.0
        } else {
            Self::simd_sum(a) / (a.len() as f64)
        }
    }

    fn simd_max_element(a: &ArrayView1<Self>) -> Self {
        a.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
    }

    fn simd_min_element(a: &ArrayView1<Self>) -> Self {
        a.fold(f64::INFINITY, |acc, &x| acc.min(x))
    }

    #[cfg(feature = "simd")]
    fn simd_fma(a: &ArrayView1<Self>, b: &ArrayView1<Self>, c: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_fused_multiply_add_f64(a, b, c)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_fma(a: &ArrayView1<Self>, b: &ArrayView1<Self>, c: &ArrayView1<Self>) -> Array1<Self> {
        // Scalar fallback: element-wise a * b + c.
        let mut result = Array1::zeros(a.len());
        for i in 0..a.len() {
            result[i] = a[i] * b[i] + c[i];
        }
        result
    }

    #[cfg(feature = "simd")]
    fn simd_add_cache_optimized(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_add_cache_optimized_f64(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_add_cache_optimized(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a + b
    }

    #[cfg(feature = "simd")]
    fn simd_fma_advanced_optimized(
        a: &ArrayView1<Self>,
        b: &ArrayView1<Self>,
        c: &ArrayView1<Self>,
    ) -> Array1<Self> {
        crate::simd::simd_fma_advanced_optimized_f64(a, b, c)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_fma_advanced_optimized(
        a: &ArrayView1<Self>,
        b: &ArrayView1<Self>,
        c: &ArrayView1<Self>,
    ) -> Array1<Self> {
        // Scalar fallback: identical to `simd_fma`.
        let mut result = Array1::zeros(a.len());
        for i in 0..a.len() {
            result[i] = a[i] * b[i] + c[i];
        }
        result
    }

    #[cfg(feature = "simd")]
    fn simd_add_adaptive(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        crate::simd::simd_adaptive_add_f64(a, b)
    }

    #[cfg(not(feature = "simd"))]
    fn simd_add_adaptive(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        a + b
    }

    fn simd_transpose(a: &ArrayView2<Self>) -> Array2<Self> {
        a.t().to_owned()
    }

    fn simd_abs(a: &ArrayView1<Self>) -> Array1<Self> {
        a.mapv(|x| x.abs())
    }

    fn simd_sqrt(a: &ArrayView1<Self>) -> Array1<Self> {
        a.mapv(|x| x.sqrt())
    }

    fn simd_sum_squares(a: &ArrayView1<Self>) -> Self {
        a.iter().map(|&x| x * x).sum()
    }

    fn simd_multiply(a: &ArrayView1<Self>, b: &ArrayView1<Self>) -> Array1<Self> {
        Self::simd_mul(a, b)
    }

    #[cfg(feature = "simd")]
    fn simd_available() -> bool {
        true
    }

    #[cfg(not(feature = "simd"))]
    fn simd_available() -> bool {
        false
    }
}
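
/// Compile-time view of the available acceleration backends. All flags are
/// derived from `cfg!` checks, so they describe how the crate was compiled
/// rather than the hardware probed at runtime.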
#[derive(Debug, Clone)]
pub struct PlatformCapabilities {
    pub simd_available: bool,
    pub gpu_available: bool,
    pub cuda_available: bool,
    pub opencl_available: bool,
    pub metal_available: bool,
    pub avx2_available: bool,
    pub avx512_available: bool,
    pub neon_available: bool,
}

impl PlatformCapabilities {
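    /// Builds the capability set from compile-time `cfg!` flags. Note that
    /// `target_feature` checks (AVX2/AVX-512) are resolved at compile time,
    /// so a binary built without e.g. `-C target-feature=+avx2` reports
    /// `false` even on hardware that supports it.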
    pub fn detect() -> Self {
        Self {
            simd_available: cfg!(feature = "simd"),
            gpu_available: cfg!(feature = "gpu"),
            cuda_available: cfg!(all(feature = "gpu", feature = "cuda")),
            opencl_available: cfg!(all(feature = "gpu", feature = "opencl")),
            metal_available: cfg!(all(feature = "gpu", feature = "metal", target_os = "macos")),
            avx2_available: cfg!(target_feature = "avx2"),
            avx512_available: cfg!(target_feature = "avx512f"),
            neon_available: cfg!(target_arch = "aarch64"),
        }
    }

    pub fn summary(&self) -> String {
        let mut features = Vec::new();

        if self.simd_available {
            features.push("SIMD");
        }
        if self.gpu_available {
            features.push("GPU");
        }
        if self.cuda_available {
            features.push("CUDA");
        }
        if self.opencl_available {
            features.push("OpenCL");
        }
        if self.metal_available {
            features.push("Metal");
        }
        if self.avx2_available {
            features.push("AVX2");
        }
        if self.avx512_available {
            features.push("AVX512");
        }
        if self.neon_available {
            features.push("NEON");
        }

        if features.is_empty() {
            "No acceleration features available".to_string()
        } else {
            format!(
                "Available acceleration: {features}",
                features = features.join(", ")
            )
        }
    }
}
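
/// Picks an execution backend for an operation from the detected capabilities
/// and the problem size, using fixed size thresholds as heuristics.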
pub struct AutoOptimizer {
    capabilities: PlatformCapabilities,
}

impl AutoOptimizer {
    pub fn new() -> Self {
        Self {
            capabilities: PlatformCapabilities::detect(),
        }
    }
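
    /// GPU offload only pays off once per-call transfer and launch overhead
    /// is amortized; the 10 000-element cutoff is a fixed heuristic, not a
    /// measured crossover point.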
    pub fn should_use_gpu(&self, size: usize) -> bool {
        self.capabilities.gpu_available && size > 10_000
    }

    /// Metal kicks in at a lower threshold than the other GPU backends.
    pub fn should_use_metal(&self, size: usize) -> bool {
        self.capabilities.metal_available && size > 1024
    }

    /// SIMD is worthwhile even for fairly small vectors.
    pub fn should_use_simd(&self, size: usize) -> bool {
        self.capabilities.simd_available && size > 64
    }
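
    /// Chooses a GEMM backend from the total multiply-accumulate count
    /// `m * n * k`.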
    pub fn select_gemm_impl(&self, m: usize, n: usize, k: usize) -> &'static str {
        let total_ops = m * n * k;

        if self.capabilities.metal_available && total_ops > 8192 {
            return "Metal";
        }

        if self.should_use_gpu(total_ops) {
            if self.capabilities.cuda_available {
                "CUDA"
            } else if self.capabilities.metal_available {
                "Metal"
            } else if self.capabilities.opencl_available {
                "OpenCL"
            } else {
                "GPU"
            }
        } else if self.should_use_simd(total_ops) {
            "SIMD"
        } else {
            "Scalar"
        }
    }
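
    /// Chooses a backend for element-wise vector kernels of length `size`.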
    pub fn select_vector_impl(&self, size: usize) -> &'static str {
        if self.capabilities.metal_available && size > 1024 {
            return "Metal";
        }

        if self.should_use_gpu(size) {
            if self.capabilities.cuda_available {
                "CUDA"
            } else if self.capabilities.metal_available {
                "Metal"
            } else if self.capabilities.opencl_available {
                "OpenCL"
            } else {
                "GPU"
            }
        } else if self.should_use_simd(size) {
            if self.capabilities.avx512_available {
                "AVX512"
            } else if self.capabilities.avx2_available {
                "AVX2"
            } else if self.capabilities.neon_available {
                "NEON"
            } else {
                "SIMD"
            }
        } else {
            "Scalar"
        }
    }
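
    /// Chooses a backend for reductions (sum, mean, ...). The GPU check is
    /// applied to `size * 2`, halving the effective threshold relative to
    /// element-wise operations.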
    pub fn select_reduction_impl(&self, size: usize) -> &'static str {
        if self.capabilities.metal_available && size > 4096 {
            return "Metal";
        }

        if self.should_use_gpu(size * 2) {
            if self.capabilities.cuda_available {
                "CUDA"
            } else if self.capabilities.metal_available {
                "Metal"
            } else {
                "GPU"
            }
        } else if self.should_use_simd(size) {
            "SIMD"
        } else {
            "Scalar"
        }
    }
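
    /// Chooses an FFT backend, preferring Metal Performance Shaders and then
    /// cuFFT once the transform is large enough.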
    pub fn select_fft_impl(&self, size: usize) -> &'static str {
        if self.capabilities.metal_available && size > 512 {
            return "Metal-MPS";
        }

        if self.capabilities.cuda_available && size > 1024 {
            "cuFFT"
        } else if self.should_use_simd(size) {
            "SIMD"
        } else {
            "Scalar"
        }
    }
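
    /// True on Apple Silicon (macOS on aarch64), where CPU and GPU share one
    /// physical memory pool and host/device copies can be skipped.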
    pub fn has_unified_memory(&self) -> bool {
        cfg!(all(target_os = "macos", target_arch = "aarch64"))
    }
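
    /// Maps an operation name (e.g. "gemm", "dot", "sum", "fft") to a backend
    /// recommendation; unrecognized operations fall back to "Scalar".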
    pub fn recommend(&self, operation: &str, size: usize) -> String {
        let recommendation = match operation {
            "gemm" | "matmul" => self.select_gemm_impl(size, size, size),
            "vector" | "axpy" | "dot" => self.select_vector_impl(size),
            "reduction" | "sum" | "mean" => self.select_reduction_impl(size),
            "fft" => self.select_fft_impl(size),
            _ => "Scalar",
        };

        if self.has_unified_memory() && recommendation == "Metal" {
            format!("{recommendation} (Unified Memory)")
        } else {
            recommendation.to_string()
        }
    }
}

impl Default for AutoOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    #[test]
    fn test_simd_unified_ops_f32() {
        let a = arr1(&[1.0f32, 2.0, 3.0, 4.0]);
        let b = arr1(&[5.0f32, 6.0, 7.0, 8.0]);

        let sum = f32::simd_add(&a.view(), &b.view());
        assert_eq!(sum, arr1(&[6.0f32, 8.0, 10.0, 12.0]));

        let product = f32::simd_mul(&a.view(), &b.view());
        assert_eq!(product, arr1(&[5.0f32, 12.0, 21.0, 32.0]));

        let dot = f32::simd_dot(&a.view(), &b.view());
        assert_eq!(dot, 70.0);
    }
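
    // Companion test (an added sketch) exercising the f64 implementation,
    // including the FMA path, with the same exact small-integer arithmetic
    // style as the f32 test above.
    #[test]
    fn test_simd_unified_ops_f64() {
        let a = arr1(&[1.0f64, 2.0, 3.0]);
        let b = arr1(&[4.0f64, 5.0, 6.0]);
        let c = arr1(&[0.5f64, 0.5, 0.5]);

        let diff = f64::simd_sub(&a.view(), &b.view());
        assert_eq!(diff, arr1(&[-3.0f64, -3.0, -3.0]));

        let dot = f64::simd_dot(&a.view(), &b.view());
        assert_eq!(dot, 32.0);

        // fma = a * b + c, element-wise.
        let fma = f64::simd_fma(&a.view(), &b.view(), &c.view());
        assert_eq!(fma, arr1(&[4.5f64, 10.5, 18.5]));
    }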

    #[test]
    fn test_platform_capabilities() {
        let caps = PlatformCapabilities::detect();
        println!("{}", caps.summary());
    }

    #[test]
    fn test_auto_optimizer() {
        let optimizer = AutoOptimizer::new();

        // Small problems must stay on the CPU regardless of GPU availability.
        assert!(!optimizer.should_use_gpu(100));

        let large_size = 100_000;
        if optimizer.capabilities.gpu_available {
            assert!(optimizer.should_use_gpu(large_size));
        }
    }
}
886}