1use std::cmp::min;
17use std::collections::HashMap;
18use std::time::Instant;
19
20use scirs2_core::parallel_ops::*;
22use torsh_core::{
23 dtype::FloatElement,
24 error::{Result, TorshError},
25};
26
/// Tuning knobs for [`AlgorithmicOptimizer`].
#[derive(Debug, Clone)]
pub struct AlgorithmConfig {
    /// Prefer the algorithm recorded as fastest in the performance history;
    /// when false, matmul always uses the blocked kernel.
    pub enable_adaptive_selection: bool,
    /// Minimum problem size before advanced algorithms are considered.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub min_size_for_advanced: usize,
    /// L1 data-cache size in bytes; drives matmul block-size selection.
    pub l1_cache_size: usize,
    /// L2 cache size in bytes (currently unused in this file).
    pub l2_cache_size: usize,
    /// L3 cache size in bytes (currently unused in this file).
    pub l3_cache_size: usize,
    /// Allow `execute_fused_operations` to run; disabled => it errors.
    pub enable_operation_fusion: bool,
    /// Maximum operations per fusion chain.
    /// NOTE(review): not enforced anywhere in this file — confirm.
    pub max_fusion_chain: usize,
    /// Toggle for numerically stable variants (not consulted in this file).
    pub enable_numerical_stability: bool,
    /// Strategy used to schedule parallel work.
    pub scheduling_strategy: SchedulingStrategy,
}
50
51impl Default for AlgorithmConfig {
52 fn default() -> Self {
53 Self {
54 enable_adaptive_selection: true,
55 min_size_for_advanced: 64,
56 l1_cache_size: 32 * 1024, l2_cache_size: 256 * 1024, l3_cache_size: 8 * 1024 * 1024, enable_operation_fusion: true,
60 max_fusion_chain: 8,
61 enable_numerical_stability: true,
62 scheduling_strategy: SchedulingStrategy::WorkStealing,
63 }
64 }
65}
66
/// Parallel scheduling strategies selectable via [`AlgorithmConfig`].
/// NOTE(review): only stored, never matched on in this file — the actual
/// scheduling is delegated to `scirs2_core::parallel_ops`; confirm wiring.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingStrategy {
    /// Fixed, up-front work partitioning.
    Static,
    /// Work-stealing between worker threads (default).
    WorkStealing,
    /// Runtime-adaptive partitioning.
    Adaptive,
    /// NUMA-topology-aware placement.
    NumaAware,
}
79
/// Selects and runs optimized kernels (matmul, conv2d, fused element-wise
/// chains) and learns the fastest algorithm per problem shape over time.
pub struct AlgorithmicOptimizer {
    /// Static tuning configuration.
    config: AlgorithmConfig,
    /// Per-shape timing history; guarded by an `RwLock` so concurrent
    /// readers can consult it while writers record new timings.
    performance_history: std::sync::RwLock<HashMap<OperationSignature, PerformanceMetrics>>,
}
86
87impl AlgorithmicOptimizer {
88 pub fn new() -> Self {
90 Self::with_config(AlgorithmConfig::default())
91 }
92
93 pub fn with_config(config: AlgorithmConfig) -> Self {
95 Self {
96 config,
97 performance_history: std::sync::RwLock::new(HashMap::new()),
98 }
99 }
100
    /// Computes `c = a * b` (row-major: `a` is m x k, `b` is k x n, `c` is
    /// m x n), picking an algorithm via [`Self::select_matmul_algorithm`]
    /// and recording the elapsed time for future adaptive selection.
    ///
    /// # Errors
    /// Propagates any error from the selected kernel.
    pub fn optimized_matmul<T>(
        &self,
        a: &[T],
        b: &[T],
        c: &mut [T],
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<()>
    where
        T: FloatElement + Send + Sync + std::ops::AddAssign,
    {
        // Placeholder profiling hook; compiles to nothing today.
        #[cfg(feature = "profiling")]
        {
        }
        let signature = OperationSignature::MatMul { m, k, n };

        let algorithm = self.select_matmul_algorithm(&signature);

        let start_time = Instant::now();

        match algorithm {
            MatMulAlgorithm::Naive => self.naive_matmul(a, b, c, m, k, n)?,
            MatMulAlgorithm::Blocked => self.blocked_matmul(a, b, c, m, k, n)?,
            MatMulAlgorithm::Strassen => self.strassen_matmul(a, b, c, m, k, n)?,
            MatMulAlgorithm::CacheOblivious => self.cache_oblivious_matmul(a, b, c, m, k, n)?,
            MatMulAlgorithm::Parallel => self.parallel_matmul(a, b, c, m, k, n)?,
        }

        let duration = start_time.elapsed();
        // Feed the timing back so the next call with this shape can pick
        // the empirically fastest algorithm.
        self.record_performance(signature, algorithm, duration);

        Ok(())
    }
139
140 fn select_matmul_algorithm(&self, signature: &OperationSignature) -> MatMulAlgorithm {
142 if !self.config.enable_adaptive_selection {
143 return MatMulAlgorithm::Blocked; }
145
146 if let Some(metrics) = self
148 .performance_history
149 .read()
150 .expect("lock should not be poisoned")
151 .get(signature)
152 {
153 return metrics
154 .best_algorithm
155 .clone()
156 .unwrap_or(MatMulAlgorithm::Blocked);
157 }
158
159 match signature {
161 OperationSignature::MatMul { m, k, n } => {
162 let total_size = m * k * n;
163
164 if total_size < 1000 {
165 MatMulAlgorithm::Naive
166 } else if total_size < 10000 {
167 MatMulAlgorithm::Blocked
168 } else if *m >= 1024 && *k >= 1024 && *n >= 1024 {
169 MatMulAlgorithm::Strassen
170 } else if total_size > 100000 {
171 MatMulAlgorithm::Parallel
172 } else {
173 MatMulAlgorithm::CacheOblivious
174 }
175 }
176 }
177 }
178
179 fn naive_matmul<T>(
181 &self,
182 a: &[T],
183 b: &[T],
184 c: &mut [T],
185 m: usize,
186 k: usize,
187 n: usize,
188 ) -> Result<()>
189 where
190 T: FloatElement + std::ops::AddAssign,
191 {
192 for i in 0..m {
193 for j in 0..n {
194 let mut sum = <T as torsh_core::TensorElement>::zero();
195 for l in 0..k {
196 sum += a[i * k + l] * b[l * n + j];
197 }
198 c[i * n + j] = sum;
199 }
200 }
201 Ok(())
202 }
203
204 fn blocked_matmul<T>(
206 &self,
207 a: &[T],
208 b: &[T],
209 c: &mut [T],
210 m: usize,
211 k: usize,
212 n: usize,
213 ) -> Result<()>
214 where
215 T: FloatElement + std::ops::AddAssign,
216 {
217 let block_size = self.calculate_optimal_block_size(m, k, n);
219
220 for i_block in (0..m).step_by(block_size) {
221 for j_block in (0..n).step_by(block_size) {
222 for k_block in (0..k).step_by(block_size) {
223 let i_end = min(i_block + block_size, m);
224 let j_end = min(j_block + block_size, n);
225 let k_end = min(k_block + block_size, k);
226
227 for i in i_block..i_end {
229 for j in j_block..j_end {
230 let mut sum = if k_block == 0 {
231 <T as torsh_core::TensorElement>::zero()
232 } else {
233 c[i * n + j]
234 };
235 for l in k_block..k_end {
236 sum += a[i * k + l] * b[l * n + j];
237 }
238 c[i * n + j] = sum;
239 }
240 }
241 }
242 }
243 }
244 Ok(())
245 }
246
247 fn strassen_matmul<T>(
249 &self,
250 a: &[T],
251 b: &[T],
252 c: &mut [T],
253 m: usize,
254 k: usize,
255 n: usize,
256 ) -> Result<()>
257 where
258 T: FloatElement + Send + Sync + std::ops::AddAssign,
259 {
260 if m != k || k != n || m < 128 {
262 return self.blocked_matmul(a, b, c, m, k, n);
263 }
264
265 self.strassen_recursive(a, b, c, m, 0, 0, 0, 0, 0, 0)
266 }
267
    /// One level of Strassen decomposition of an (n x n) block product.
    ///
    /// All three buffers are indexed with row stride `n` at the given
    /// (row, col) offsets. NOTE(review): despite the name, this function
    /// never calls itself — the seven half-size products M1..M7 are computed
    /// with `blocked_matmul` on compacted temporaries — so it is only
    /// correct when called with offsets 0 and `n` equal to the full matrix
    /// dimension, which is exactly how `strassen_matmul` invokes it.
    fn strassen_recursive<T>(
        &self,
        a: &[T],
        b: &[T],
        c: &mut [T],
        n: usize,
        a_row: usize,
        a_col: usize,
        b_row: usize,
        b_col: usize,
        c_row: usize,
        c_col: usize,
    ) -> Result<()>
    where
        T: FloatElement + Send + Sync + std::ops::AddAssign,
    {
        // Base case: plain triple loop on the sub-blocks.
        if n <= 64 {
            for i in 0..n {
                for j in 0..n {
                    let mut sum = <T as torsh_core::TensorElement>::zero();
                    for k in 0..n {
                        let a_val = a[(a_row + i) * n + (a_col + k)];
                        let b_val = b[(b_row + k) * n + (b_col + j)];
                        sum += a_val * b_val;
                    }
                    c[(c_row + i) * n + (c_col + j)] = sum;
                }
            }
            return Ok(());
        }

        let half = n / 2;

        // Buffers for Strassen's seven half-size products.
        let temp_size = half * half;
        let mut m1 = vec![<T as torsh_core::TensorElement>::zero(); temp_size];
        let mut m2 = vec![<T as torsh_core::TensorElement>::zero(); temp_size];
        let mut m3 = vec![<T as torsh_core::TensorElement>::zero(); temp_size];
        let mut m4 = vec![<T as torsh_core::TensorElement>::zero(); temp_size];
        let mut m5 = vec![<T as torsh_core::TensorElement>::zero(); temp_size];
        let mut m6 = vec![<T as torsh_core::TensorElement>::zero(); temp_size];
        let mut m7 = vec![<T as torsh_core::TensorElement>::zero(); temp_size];

        // Scratch buffers holding compacted (half x half) operands.
        let mut temp_a = vec![<T as torsh_core::TensorElement>::zero(); temp_size];
        let mut temp_b = vec![<T as torsh_core::TensorElement>::zero(); temp_size];

        // temp[i][j] = source[quad1][i][j] + source[quad2][i][j]
        let add_quadrants = |temp: &mut [T],
                             quad1_row: usize,
                             quad1_col: usize,
                             quad2_row: usize,
                             quad2_col: usize,
                             source: &[T]| {
            for i in 0..half {
                for j in 0..half {
                    let val1 = source[(quad1_row + i) * n + (quad1_col + j)];
                    let val2 = source[(quad2_row + i) * n + (quad2_col + j)];
                    temp[i * half + j] = val1 + val2;
                }
            }
        };

        // temp[i][j] = source[quad1][i][j] - source[quad2][i][j]
        let sub_quadrants = |temp: &mut [T],
                             quad1_row: usize,
                             quad1_col: usize,
                             quad2_row: usize,
                             quad2_col: usize,
                             source: &[T]| {
            for i in 0..half {
                for j in 0..half {
                    let val1 = source[(quad1_row + i) * n + (quad1_col + j)];
                    let val2 = source[(quad2_row + i) * n + (quad2_col + j)];
                    temp[i * half + j] = val1 - val2;
                }
            }
        };

        // M1 = (A11 + A22) * (B11 + B22)
        add_quadrants(&mut temp_a, a_row, a_col, a_row + half, a_col + half, a);
        add_quadrants(&mut temp_b, b_row, b_col, b_row + half, b_col + half, b);
        self.blocked_matmul(&temp_a, &temp_b, &mut m1, half, half, half)?;

        // M2 = (A21 + A22) * B11
        add_quadrants(
            &mut temp_a,
            a_row + half,
            a_col,
            a_row + half,
            a_col + half,
            a,
        );
        for i in 0..half {
            for j in 0..half {
                temp_b[i * half + j] = b[(b_row + i) * n + (b_col + j)];
            }
        }
        self.blocked_matmul(&temp_a, &temp_b, &mut m2, half, half, half)?;

        // M3 = A11 * (B12 - B22)
        for i in 0..half {
            for j in 0..half {
                temp_a[i * half + j] = a[(a_row + i) * n + (a_col + j)];
            }
        }
        sub_quadrants(
            &mut temp_b,
            b_row,
            b_col + half,
            b_row + half,
            b_col + half,
            b,
        );
        self.blocked_matmul(&temp_a, &temp_b, &mut m3, half, half, half)?;

        // M4 = A22 * (B21 - B11)
        for i in 0..half {
            for j in 0..half {
                temp_a[i * half + j] = a[(a_row + half + i) * n + (a_col + half + j)];
            }
        }
        sub_quadrants(&mut temp_b, b_row + half, b_col, b_row, b_col, b);
        self.blocked_matmul(&temp_a, &temp_b, &mut m4, half, half, half)?;

        // M5 = (A11 + A12) * B22
        add_quadrants(&mut temp_a, a_row, a_col, a_row, a_col + half, a);
        for i in 0..half {
            for j in 0..half {
                temp_b[i * half + j] = b[(b_row + half + i) * n + (b_col + half + j)];
            }
        }
        self.blocked_matmul(&temp_a, &temp_b, &mut m5, half, half, half)?;

        // M6 = (A21 - A11) * (B11 + B12)
        sub_quadrants(&mut temp_a, a_row + half, a_col, a_row, a_col, a);
        add_quadrants(&mut temp_b, b_row, b_col, b_row, b_col + half, b);
        self.blocked_matmul(&temp_a, &temp_b, &mut m6, half, half, half)?;

        // M7 = (A12 - A22) * (B21 + B22)
        sub_quadrants(
            &mut temp_a,
            a_row,
            a_col + half,
            a_row + half,
            a_col + half,
            a,
        );
        add_quadrants(
            &mut temp_b,
            b_row + half,
            b_col,
            b_row + half,
            b_col + half,
            b,
        );
        self.blocked_matmul(&temp_a, &temp_b, &mut m7, half, half, half)?;

        // C11 = M1 + M4 - M5 + M7
        for i in 0..half {
            for j in 0..half {
                c[(c_row + i) * n + (c_col + j)] =
                    m1[i * half + j] + m4[i * half + j] - m5[i * half + j] + m7[i * half + j];
            }
        }

        // C12 = M3 + M5
        for i in 0..half {
            for j in 0..half {
                c[(c_row + i) * n + (c_col + half + j)] = m3[i * half + j] + m5[i * half + j];
            }
        }

        // C21 = M2 + M4
        for i in 0..half {
            for j in 0..half {
                c[(c_row + half + i) * n + (c_col + j)] = m2[i * half + j] + m4[i * half + j];
            }
        }

        // C22 = M1 - M2 + M3 + M6
        for i in 0..half {
            for j in 0..half {
                c[(c_row + half + i) * n + (c_col + half + j)] =
                    m1[i * half + j] - m2[i * half + j] + m3[i * half + j] + m6[i * half + j];
            }
        }

        Ok(())
    }
461
462 fn cache_oblivious_matmul<T>(
464 &self,
465 a: &[T],
466 b: &[T],
467 c: &mut [T],
468 m: usize,
469 k: usize,
470 n: usize,
471 ) -> Result<()>
472 where
473 T: FloatElement + std::ops::AddAssign,
474 {
475 self.cache_oblivious_recursive(a, b, c, m, k, n, 0, 0, 0, 0, 0, 0)
476 }
477
    /// Divide-and-conquer worker: splits the largest of (m, k, n) in half
    /// and recurses, bottoming out in `naive_matmul_region`.
    ///
    /// NOTE(review): after a k-split the sub-calls pass the reduced `k`,
    /// and after an n-split the reduced `n`, but `naive_matmul_region`
    /// also uses `k`/`n` as the ROW STRIDES of `a` and `b`/`c`. The
    /// underlying buffers keep their original strides, so any k- or
    /// n-split mis-addresses memory (only m-splits are safe). Fixing this
    /// requires threading the original strides through the recursion.
    fn cache_oblivious_recursive<T>(
        &self,
        a: &[T],
        b: &[T],
        c: &mut [T],
        m: usize,
        k: usize,
        n: usize,
        a_row: usize,
        a_col: usize,
        b_row: usize,
        b_col: usize,
        c_row: usize,
        c_col: usize,
    ) -> Result<()>
    where
        T: FloatElement + std::ops::AddAssign,
    {
        // Base case: small block, multiply directly (accumulates into `c`).
        if m <= 32 || k <= 32 || n <= 32 {
            return self
                .naive_matmul_region(a, b, c, m, k, n, a_row, a_col, b_row, b_col, c_row, c_col);
        }

        if m >= k && m >= n {
            // Split the m (output rows) dimension.
            let m1 = m / 2;
            let m2 = m - m1;

            self.cache_oblivious_recursive(
                a, b, c, m1, k, n, a_row, a_col, b_row, b_col, c_row, c_col,
            )?;

            self.cache_oblivious_recursive(
                a,
                b,
                c,
                m2,
                k,
                n,
                a_row + m1,
                a_col,
                b_row,
                b_col,
                c_row + m1,
                c_col,
            )?;
        } else if k >= n {
            // Split the k (inner) dimension; halves accumulate into the
            // same output region.
            let k1 = k / 2;
            let k2 = k - k1;

            self.cache_oblivious_recursive(
                a, b, c, m, k1, n, a_row, a_col, b_row, b_col, c_row, c_col,
            )?;

            self.cache_oblivious_recursive(
                a,
                b,
                c,
                m,
                k2,
                n,
                a_row,
                a_col + k1,
                b_row + k1,
                b_col,
                c_row,
                c_col,
            )?;
        } else {
            // Split the n (output columns) dimension.
            let n1 = n / 2;
            let n2 = n - n1;

            self.cache_oblivious_recursive(
                a, b, c, m, k, n1, a_row, a_col, b_row, b_col, c_row, c_col,
            )?;

            self.cache_oblivious_recursive(
                a,
                b,
                c,
                m,
                k,
                n2,
                a_row,
                a_col,
                b_row,
                b_col + n1,
                c_row,
                c_col + n1,
            )?;
        }

        Ok(())
    }
579
    /// Multiplies an (m x k) sub-block of `a` by a (k x n) sub-block of
    /// `b`, ACCUMULATING (`+=`) into the corresponding region of `c`.
    /// Accumulation is required because the cache-oblivious recursion
    /// splits the k dimension into partial products; callers must zero the
    /// output region first.
    ///
    /// NOTE(review): `k` and `n` double as both block extents and row
    /// strides of the full buffers — this only holds when the caller has
    /// not split k or n; see the note on `cache_oblivious_recursive`.
    fn naive_matmul_region<T>(
        &self,
        a: &[T],
        b: &[T],
        c: &mut [T],
        m: usize,
        k: usize,
        n: usize,
        a_row: usize,
        a_col: usize,
        b_row: usize,
        b_col: usize,
        c_row: usize,
        c_col: usize,
    ) -> Result<()>
    where
        T: FloatElement + std::ops::AddAssign,
    {
        for i in 0..m {
            for j in 0..n {
                let mut sum = <T as torsh_core::TensorElement>::zero();
                for l in 0..k {
                    let a_idx = (a_row + i) * k + (a_col + l);
                    let b_idx = (b_row + l) * n + (b_col + j);
                    sum += a[a_idx] * b[b_idx];
                }
                let c_idx = (c_row + i) * n + (c_col + j);
                // Accumulate, do not overwrite (partial products).
                c[c_idx] += sum;
            }
        }
        Ok(())
    }
613
    /// Parallel blocked matmul: overwrites `c` with `a * b`, distributing
    /// (i, j) output tiles across threads via `parallel_map_result` from
    /// `scirs2_core::parallel_ops`.
    ///
    /// Falls back to the serial blocked kernel when the machine has one
    /// core or the problem is too small to amortize scheduling overhead.
    fn parallel_matmul<T>(
        &self,
        a: &[T],
        b: &[T],
        c: &mut [T],
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<()>
    where
        T: FloatElement + Send + Sync + std::ops::AddAssign,
    {
        let num_cores = get_num_threads();
        let block_size = self.calculate_optimal_block_size(m, k, n);

        // Require at least ~100k multiply-adds per core before threading.
        let total_operations = m * k * n;
        let min_work_per_core = 100_000;
        let should_parallelize = num_cores > 1 && total_operations > min_work_per_core * num_cores;

        if !should_parallelize {
            return self.blocked_matmul(a, b, c, m, k, n);
        }

        // One work item per (row-block, column-block) output tile.
        let work_items: Vec<_> = (0..m)
            .step_by(block_size)
            .flat_map(|i| (0..n).step_by(block_size).map(move |j| (i, j)))
            .collect();

        // Each tile computes its results into a private Vec of
        // (flat index, value) pairs, so no worker writes `c` concurrently.
        let results: Result<Vec<_>> = parallel_map_result(&work_items, |&(i_block, j_block)| {
            let i_end = min(i_block + block_size, m);
            let j_end = min(j_block + block_size, n);

            let mut block_results = Vec::new();
            for i in i_block..i_end {
                for j in j_block..j_end {
                    let mut sum = <T as torsh_core::TensorElement>::zero();
                    for l in 0..k {
                        sum += a[i * k + l] * b[l * n + j];
                    }
                    let idx = i * n + j;
                    block_results.push((idx, sum));
                }
            }
            Ok(block_results)
        });

        // Single-threaded write-back of all tile results.
        for block_results in results? {
            for (idx, value) in block_results {
                c[idx] = value;
            }
        }

        Ok(())
    }
674
675 fn calculate_optimal_block_size(&self, m: usize, k: usize, n: usize) -> usize {
677 let element_size = std::mem::size_of::<f32>(); let l1_elements = self.config.l1_cache_size / element_size;
683
684 let cache_optimal = (l1_elements as f64 / 3.0).sqrt() as usize;
687
688 let dim_optimal = m.min(k).min(n);
690
691 let optimal_block = cache_optimal.min(dim_optimal);
693
694 let clamped = optimal_block.clamp(16, 256);
696
697 let log2 = (clamped as f64).log2().round() as u32;
699 2usize.pow(log2).min(256)
700 }
701
702 fn record_performance(
704 &self,
705 signature: OperationSignature,
706 algorithm: MatMulAlgorithm,
707 duration: std::time::Duration,
708 ) {
709 let mut history = self
710 .performance_history
711 .write()
712 .expect("lock should not be poisoned");
713 let metrics = history
714 .entry(signature)
715 .or_insert_with(PerformanceMetrics::default);
716
717 metrics.update_performance(algorithm, duration);
718 }
719
720 pub fn optimized_conv2d<T>(
722 &self,
723 input: &[T],
724 kernel: &[T],
725 output: &mut [T],
726 input_h: usize,
727 input_w: usize,
728 kernel_h: usize,
729 kernel_w: usize,
730 stride: usize,
731 padding: usize,
732 ) -> Result<()>
733 where
734 T: FloatElement + Send + Sync + std::ops::AddAssign,
735 {
736 #[cfg(feature = "profiling")]
737 {
738 }
740
741 let output_h = (input_h + 2 * padding - kernel_h) / stride + 1;
743 let output_w = (input_w + 2 * padding - kernel_w) / stride + 1;
744 let expected_output_size = output_h * output_w;
745
746 if output.len() < expected_output_size {
748 return Err(torsh_core::error::TorshError::InvalidShape(format!(
749 "Output buffer too small: expected at least {} ({}x{}) elements, got {}",
750 expected_output_size,
751 output_h,
752 output_w,
753 output.len()
754 )));
755 }
756
757 if kernel_h * kernel_w <= 9 && input_h * input_w > 10000 {
773 self.direct_conv2d(
775 input, kernel, output, input_h, input_w, kernel_h, kernel_w, stride, padding,
776 )
777 } else if kernel_h >= 7 && kernel_w >= 7 {
778 self.fft_conv2d(
780 input, kernel, output, input_h, input_w, kernel_h, kernel_w, stride, padding,
781 )
782 } else {
783 self.winograd_conv2d(
785 input, kernel, output, input_h, input_w, kernel_h, kernel_w, stride, padding,
786 )
787 }
788 }
789
790 fn direct_conv2d<T>(
792 &self,
793 input: &[T],
794 kernel: &[T],
795 output: &mut [T],
796 input_h: usize,
797 input_w: usize,
798 kernel_h: usize,
799 kernel_w: usize,
800 stride: usize,
801 padding: usize,
802 ) -> Result<()>
803 where
804 T: FloatElement + Send + Sync + std::ops::AddAssign,
805 {
806 let output_h = (input_h + 2 * padding - kernel_h) / stride + 1;
807 let output_w = (input_w + 2 * padding - kernel_w) / stride + 1;
808
809 let output_positions: Vec<_> = (0..output_h)
811 .flat_map(|out_y| (0..output_w).map(move |out_x| (out_y, out_x)))
812 .collect();
813
814 let results: Vec<_> = parallel_map_collect(output_positions, |(out_y, out_x)| {
815 let mut sum = <T as torsh_core::TensorElement>::zero();
816
817 for ky in 0..kernel_h {
818 for kx in 0..kernel_w {
819 let in_y = out_y * stride + ky;
820 let in_x = out_x * stride + kx;
821
822 if in_y >= padding
823 && in_y < input_h + padding
824 && in_x >= padding
825 && in_x < input_w + padding
826 {
827 let input_y = in_y - padding;
828 let input_x = in_x - padding;
829
830 if input_y < input_h && input_x < input_w {
831 sum += input[input_y * input_w + input_x] * kernel[ky * kernel_w + kx];
832 }
833 }
834 }
835 }
836
837 (out_y * output_w + out_x, sum)
838 });
839
840 for (idx, value) in results {
842 output[idx] = value;
843 }
844
845 Ok(())
846 }
847
    /// FFT-based convolution for large kernels.
    ///
    /// TODO: not yet implemented — currently delegates to the direct
    /// sliding-window kernel, so results are correct but there is no FFT
    /// speedup.
    fn fft_conv2d<T>(
        &self,
        input: &[T],
        kernel: &[T],
        output: &mut [T],
        input_h: usize,
        input_w: usize,
        kernel_h: usize,
        kernel_w: usize,
        stride: usize,
        padding: usize,
    ) -> Result<()>
    where
        T: FloatElement + std::ops::AddAssign,
    {
        self.direct_conv2d(
            input, kernel, output, input_h, input_w, kernel_h, kernel_w, stride, padding,
        )
    }
870
    /// Winograd convolution for small kernels.
    ///
    /// TODO: not yet implemented — currently delegates to the direct
    /// sliding-window kernel, so results are correct but without the
    /// Winograd multiplication savings.
    fn winograd_conv2d<T>(
        &self,
        input: &[T],
        kernel: &[T],
        output: &mut [T],
        input_h: usize,
        input_w: usize,
        kernel_h: usize,
        kernel_w: usize,
        stride: usize,
        padding: usize,
    ) -> Result<()>
    where
        T: FloatElement + std::ops::AddAssign,
    {
        self.direct_conv2d(
            input, kernel, output, input_h, input_w, kernel_h, kernel_w, stride, padding,
        )
    }
893
894 pub fn execute_fused_operations<T>(
896 &self,
897 operations: &[FusedOperation<T>],
898 inputs: &[&[T]],
899 outputs: &mut [&mut [T]],
900 ) -> Result<()>
901 where
902 T: FloatElement + Send + Sync + std::ops::AddAssign,
903 {
904 if !self.config.enable_operation_fusion {
905 return Err(TorshError::InvalidArgument(
906 "Operation fusion disabled".to_string(),
907 ));
908 }
909
910 #[cfg(feature = "profiling")]
911 {
912 }
914
915 let compiled = self.compile_fusion(operations)?;
917 compiled.execute(inputs, outputs)
918 }
919
920 fn compile_fusion<T>(&self, operations: &[FusedOperation<T>]) -> Result<CompiledFusion<T>>
922 where
923 T: FloatElement + std::ops::AddAssign,
924 {
925 let plan = ExecutionPlan {
927 operations: operations.to_vec(),
928 optimization_level: OptimizationLevel::Aggressive,
929 };
930
931 Ok(CompiledFusion {
932 plan,
933 estimated_flops: self.estimate_fusion_flops(operations),
934 })
935 }
936
937 fn estimate_fusion_flops<T>(&self, operations: &[FusedOperation<T>]) -> usize
939 where
940 T: FloatElement + std::ops::AddAssign,
941 {
942 operations.len() * 1000 }
945
946 pub fn get_performance_stats(&self) -> AlgorithmPerformanceStats {
948 let history = self
949 .performance_history
950 .read()
951 .expect("lock should not be poisoned");
952
953 let mut total_operations = 0;
954 let mut algorithm_counts = HashMap::new();
955
956 for metrics in history.values() {
957 total_operations += metrics.execution_count;
958 if let Some(ref algorithm) = metrics.best_algorithm {
959 *algorithm_counts.entry(algorithm.clone()).or_insert(0) += 1;
960 }
961 }
962
963 AlgorithmPerformanceStats {
964 total_operations,
965 unique_operation_signatures: history.len(),
966 algorithm_distribution: algorithm_counts,
967 average_speedup: self.calculate_average_speedup(&history),
968 }
969 }
970
971 fn calculate_average_speedup(
973 &self,
974 history: &HashMap<OperationSignature, PerformanceMetrics>,
975 ) -> f64 {
976 if history.is_empty() {
977 return 1.0;
978 }
979
980 let speedups: Vec<f64> = history
981 .values()
982 .filter_map(|metrics| metrics.best_speedup)
983 .collect();
984
985 if speedups.is_empty() {
986 1.0
987 } else {
988 speedups.iter().sum::<f64>() / speedups.len() as f64
989 }
990 }
991}
992
/// Default optimizer: equivalent to [`AlgorithmicOptimizer::new`].
impl Default for AlgorithmicOptimizer {
    fn default() -> Self {
        Self::new()
    }
}
998
/// Key identifying a problem shape in the performance history.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
enum OperationSignature {
    /// Matrix multiply with dimensions (m x k) * (k x n).
    MatMul { m: usize, k: usize, n: usize },
}
1004
/// The matrix-multiplication kernels this optimizer can dispatch to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MatMulAlgorithm {
    /// Plain triple loop; best for tiny problems.
    Naive,
    /// L1-cache-blocked triple loop; the general-purpose default.
    Blocked,
    /// Strassen decomposition for large square matrices.
    Strassen,
    /// Recursive divide-and-conquer without explicit cache parameters.
    CacheOblivious,
    /// Thread-parallel blocked kernel.
    Parallel,
}
1014
/// Timing history for a single problem shape.
#[derive(Debug, Clone)]
struct PerformanceMetrics {
    // Total number of recorded executions across all algorithms.
    execution_count: usize,
    // Raw durations per algorithm; averages are recomputed on demand.
    algorithm_timings: HashMap<MatMulAlgorithm, Vec<std::time::Duration>>,
    // Algorithm with the lowest average duration so far, if any.
    best_algorithm: Option<MatMulAlgorithm>,
    // Speedup of the current best over the previous best at the time it
    // was recorded.
    best_speedup: Option<f64>,
}
1023
1024impl Default for PerformanceMetrics {
1025 fn default() -> Self {
1026 Self {
1027 execution_count: 0,
1028 algorithm_timings: HashMap::new(),
1029 best_algorithm: None,
1030 best_speedup: None,
1031 }
1032 }
1033}
1034
1035impl PerformanceMetrics {
1036 fn update_performance(&mut self, algorithm: MatMulAlgorithm, duration: std::time::Duration) {
1037 self.execution_count += 1;
1038 self.algorithm_timings
1039 .entry(algorithm.clone())
1040 .or_insert_with(Vec::new)
1041 .push(duration);
1042
1043 let avg_duration = self.average_duration(&algorithm);
1045 let current_best_duration = self
1046 .best_algorithm
1047 .as_ref()
1048 .map(|alg| self.average_duration(alg))
1049 .unwrap_or(std::time::Duration::from_secs(u64::MAX));
1050
1051 if avg_duration < current_best_duration {
1052 let speedup = current_best_duration.as_secs_f64() / avg_duration.as_secs_f64();
1053 self.best_algorithm = Some(algorithm);
1054 self.best_speedup = Some(speedup);
1055 }
1056 }
1057
1058 fn average_duration(&self, algorithm: &MatMulAlgorithm) -> std::time::Duration {
1059 static EMPTY_VEC: Vec<std::time::Duration> = Vec::new();
1060 let timings = self.algorithm_timings.get(algorithm).unwrap_or(&EMPTY_VEC);
1061 if timings.is_empty() {
1062 return std::time::Duration::from_secs(u64::MAX);
1063 }
1064
1065 let total_nanos: u128 = timings.iter().map(|d| d.as_nanos()).sum();
1066 std::time::Duration::from_nanos((total_nanos / timings.len() as u128) as u64)
1067 }
1068}
1069
/// One step in a fusable element-wise operation chain.
#[derive(Debug, Clone)]
pub enum FusedOperation<T> {
    /// Adds the scalar `alpha` to every element.
    ElementwiseAdd {
        alpha: T,
    },
    /// Multiplies every element by the scalar `scale`.
    ElementwiseMul {
        scale: T,
    },
    /// Clamps negative elements to zero.
    ReLU,
    /// Applies the logistic function 1 / (1 + e^(-x)) element-wise.
    Sigmoid,
    /// Matrix multiply step. NOTE(review): currently a no-op in
    /// `ExecutionPlan::execute` — the transpose flags are not acted on.
    MatMul {
        transpose_a: bool,
        transpose_b: bool,
    },
}
1086
/// Hashable identity of a fusion chain, intended for caching compiled
/// plans (not yet wired up — hence `dead_code`).
#[allow(dead_code)]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct FusionSignature {
    // Debug-formatted operation descriptions, in chain order.
    operation_types: Vec<String>,
    // Tensor shapes involved; currently always empty (not tracked).
    tensor_shapes: Vec<Vec<usize>>,
}
1094
1095#[allow(dead_code)]
1096impl FusionSignature {
1097 fn from_operations<T>(operations: &[FusedOperation<T>]) -> Self
1098 where
1099 T: FloatElement + std::ops::AddAssign,
1100 {
1101 let operation_types = operations.iter().map(|op| format!("{:?}", op)).collect();
1102
1103 Self {
1104 operation_types,
1105 tensor_shapes: vec![], }
1107 }
1108}
1109
/// A fusion chain that has been turned into an executable plan.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct CompiledFusion<T> {
    // The executable operation sequence.
    plan: ExecutionPlan<T>,
    // Rough cost estimate from `estimate_fusion_flops` (informational).
    estimated_flops: usize,
}
1117
impl<T> CompiledFusion<T> {
    /// Runs the compiled plan; see [`ExecutionPlan::execute`] for the
    /// in-place execution model.
    fn execute(&self, inputs: &[&[T]], outputs: &mut [&mut [T]]) -> Result<()>
    where
        T: FloatElement + std::ops::AddAssign,
    {
        self.plan.execute(inputs, outputs)
    }
}
1127
/// Ordered operation sequence produced by `compile_fusion`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct ExecutionPlan<T> {
    // Operations applied in order to the first output buffer.
    operations: Vec<FusedOperation<T>>,
    // Requested optimization level; currently informational only.
    optimization_level: OptimizationLevel,
}
1135
1136impl<T> ExecutionPlan<T> {
1137 fn execute(&self, inputs: &[&[T]], outputs: &mut [&mut [T]]) -> Result<()>
1138 where
1139 T: FloatElement + std::ops::AddAssign,
1140 {
1141 if outputs.is_empty() || inputs.is_empty() {
1142 return Ok(());
1143 }
1144
1145 let output = outputs.get_mut(0).ok_or_else(|| {
1148 torsh_core::error::TorshError::InvalidShape("No output buffer".to_string())
1149 })?;
1150
1151 if let Some(first_input) = inputs.first() {
1153 if first_input.len() == output.len() {
1154 output.copy_from_slice(first_input);
1155 }
1156 }
1157
1158 for op in &self.operations {
1160 match op {
1161 FusedOperation::ElementwiseAdd { alpha } => {
1162 for val in output.iter_mut() {
1163 *val += *alpha;
1164 }
1165 }
1166 FusedOperation::ElementwiseMul { scale } => {
1167 for val in output.iter_mut() {
1168 *val = *val * *scale;
1169 }
1170 }
1171 FusedOperation::ReLU => {
1172 let zero = <T as torsh_core::dtype::TensorElement>::zero();
1173 for val in output.iter_mut() {
1174 if *val < zero {
1175 *val = zero;
1176 }
1177 }
1178 }
1179 FusedOperation::Sigmoid => {
1180 let one = <T as num_traits::One>::one();
1181 for val in output.iter_mut() {
1182 let exp_neg = (-*val).exp();
1184 *val = one / (one + exp_neg);
1185 }
1186 }
1187 FusedOperation::MatMul { .. } => {
1188 }
1191 }
1192 }
1193
1194 Ok(())
1195 }
1196}
1197
1198#[allow(dead_code)]
1200#[derive(Debug, Clone, Copy)]
1201enum OptimizationLevel {
1202 Conservative,
1203 Moderate,
1204 Aggressive,
1205}
1206
/// Snapshot of the optimizer's learned performance history, as returned by
/// [`AlgorithmicOptimizer::get_performance_stats`].
#[derive(Debug)]
pub struct AlgorithmPerformanceStats {
    /// Total recorded executions across all shapes.
    pub total_operations: usize,
    /// Number of distinct problem shapes seen.
    pub unique_operation_signatures: usize,
    /// How many shapes currently consider each algorithm their best.
    pub algorithm_distribution: HashMap<MatMulAlgorithm, usize>,
    /// Mean of the recorded best speedups (1.0 when none recorded).
    pub average_speedup: f64,
}
1215
#[cfg(test)]
mod tests {
    use super::*;

    // Default config enables all major feature toggles.
    #[test]
    fn test_algorithm_config_default() {
        let config = AlgorithmConfig::default();
        assert!(config.enable_adaptive_selection);
        assert!(config.enable_operation_fusion);
        assert!(config.enable_numerical_stability);
    }

    // A fresh optimizer starts with an empty performance history.
    #[test]
    fn test_algorithmic_optimizer_creation() {
        let optimizer = AlgorithmicOptimizer::new();
        let stats = optimizer.get_performance_stats();

        assert_eq!(stats.total_operations, 0);
        assert_eq!(stats.unique_operation_signatures, 0);
    }

    // 100x100x100 => 1,000,000 multiply-adds: above the 100,000 parallel
    // threshold and below the 1024-per-dimension Strassen cutoff, so the
    // heuristic must pick the parallel kernel.
    #[test]
    fn test_algorithm_selection() {
        let optimizer = AlgorithmicOptimizer::new();
        let signature = OperationSignature::MatMul {
            m: 100,
            k: 100,
            n: 100,
        };

        let algorithm = optimizer.select_matmul_algorithm(&signature);
        assert!(matches!(algorithm, MatMulAlgorithm::Parallel));
    }

    // Known 2x2 product: [[1,2],[3,4]] * [[5,6],[7,8]] = [[19,22],[43,50]].
    #[test]
    fn test_small_matrix_multiplication() {
        let optimizer = AlgorithmicOptimizer::new();

        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![5.0f32, 6.0, 7.0, 8.0];
        let mut c = vec![0.0f32; 4];
        optimizer
            .optimized_matmul(&a, &b, &mut c, 2, 2, 2)
            .expect("optimized_matmul should succeed");

        assert!((c[0] - 19.0).abs() < 1e-6);
        assert!((c[1] - 22.0).abs() < 1e-6);
        assert!((c[2] - 43.0).abs() < 1e-6);
        assert!((c[3] - 50.0).abs() < 1e-6);
    }

    // Block sizes must respect the [16, 256] clamp.
    #[test]
    fn test_block_size_calculation() {
        let optimizer = AlgorithmicOptimizer::new();
        let block_size = optimizer.calculate_optimal_block_size(1000, 1000, 1000);

        assert!(block_size >= 16);
        assert!(block_size <= 256);
    }

    // Smoke test: 3x3 input, 2x2 kernel, stride 1, no padding. All inputs
    // and kernel entries are non-negative, so every output must be too.
    #[test]
    fn test_conv2d_basic() {
        let optimizer = AlgorithmicOptimizer::new();

        let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let kernel = vec![1.0f32, 0.0, 0.0, 1.0];
        let mut output = vec![0.0f32; 4];
        optimizer
            .optimized_conv2d(&input, &kernel, &mut output, 3, 3, 2, 2, 1, 0)
            .expect("operation should succeed");

        assert!(output.iter().all(|&x| x >= 0.0));
    }

    // Recording a first timing makes that algorithm the current best.
    #[test]
    fn test_performance_metrics() {
        let mut metrics = PerformanceMetrics::default();

        let duration = std::time::Duration::from_millis(100);
        metrics.update_performance(MatMulAlgorithm::Blocked, duration);

        assert_eq!(metrics.execution_count, 1);
        assert!(metrics.best_algorithm.is_some());
    }

    // The fusion signature records one entry per operation in the chain.
    #[test]
    fn test_fusion_signature() {
        let operations = vec![
            FusedOperation::ElementwiseAdd { alpha: 1.0f32 },
            FusedOperation::ReLU,
        ];

        let signature = FusionSignature::from_operations(&operations);
        assert_eq!(signature.operation_types.len(), 2);
    }
}