use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use super::{GpuCapabilities, GpuConfig, TensorCoresConfig, TensorCoresGeneration};
use crate::error::{Result, TimeSeriesError};

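/// Chunked BLAS-style linear algebra kernels (dot, norm, axpy, gemv, gemm)
/// that process data in `config.batch_size`-sized blocks.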
#[derive(Debug)]
pub struct GpuBLAS<F: Float + Debug> {
    #[allow(dead_code)]
    config: GpuConfig,
    phantom: std::marker::PhantomData<F>,
}

impl<F: Float + Debug + Clone> GpuBLAS<F> {
    pub fn new(config: GpuConfig) -> Self {
        Self {
            config,
            phantom: std::marker::PhantomData,
        }
    }

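    /// Dot product of two equal-length vectors, accumulated in
    /// `batch_size`-sized chunks.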
    pub fn dot(&self, x: &Array1<F>, y: &Array1<F>) -> Result<F> {
        if x.len() != y.len() {
            return Err(TimeSeriesError::DimensionMismatch {
                expected: x.len(),
                actual: y.len(),
            });
        }

        let n = x.len();
        let chunk_size = self.config.batch_size;
        let mut result = F::zero();

        for chunk_start in (0..n).step_by(chunk_size) {
            let chunk_end = (chunk_start + chunk_size).min(n);
            let mut chunk_sum = F::zero();

            for i in chunk_start..chunk_end {
                chunk_sum = chunk_sum + x[i] * y[i];
            }

            result = result + chunk_sum;
        }

        Ok(result)
    }

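    /// Euclidean (L2) norm of a vector, computed as `sqrt(dot(x, x))`.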
    pub fn norm(&self, x: &Array1<F>) -> Result<F> {
        let dot_product = self.dot(x, x)?;
        Ok(dot_product.sqrt())
    }

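    /// AXPY update `y = alpha * x + y`, applied in `batch_size`-sized chunks.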
    pub fn axpy(&self, alpha: F, x: &Array1<F>, y: &mut Array1<F>) -> Result<()> {
        if x.len() != y.len() {
            return Err(TimeSeriesError::DimensionMismatch {
                expected: x.len(),
                actual: y.len(),
            });
        }

        let n = x.len();
        let chunk_size = self.config.batch_size;

        for chunk_start in (0..n).step_by(chunk_size) {
            let chunk_end = (chunk_start + chunk_size).min(n);

            for i in chunk_start..chunk_end {
                y[i] = alpha * x[i] + y[i];
            }
        }

        Ok(())
    }

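    /// General matrix-vector multiply `y = alpha * A * x + beta * y`,
    /// processing rows in chunks sized to roughly `batch_size` elements.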
    pub fn gemv(
        &self,
        alpha: F,
        a: &Array2<F>,
        x: &Array1<F>,
        beta: F,
        y: &mut Array1<F>,
    ) -> Result<()> {
        let (m, n) = a.dim();

        if x.len() != n {
            return Err(TimeSeriesError::DimensionMismatch {
                expected: n,
                actual: x.len(),
            });
        }

        if y.len() != m {
            return Err(TimeSeriesError::DimensionMismatch {
                expected: m,
                actual: y.len(),
            });
        }

        // Keep the chunk size at least 1 so `step_by` never receives zero
        // (which would panic whenever a row has more columns than `batch_size`).
        let row_chunk_size = (self.config.batch_size / n.max(1)).max(1);

        for row_chunk_start in (0..m).step_by(row_chunk_size) {
            let row_chunk_end = (row_chunk_start + row_chunk_size).min(m);

            for i in row_chunk_start..row_chunk_end {
                let row = a.row(i);
                let mut sum = F::zero();

                for j in 0..n {
                    sum = sum + row[j] * x[j];
                }

                y[i] = alpha * sum + beta * y[i];
            }
        }

        Ok(())
    }

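    /// General matrix-matrix multiply `C = alpha * A * B + beta * C`, using a
    /// tiled loop whose tile edge is derived from `batch_size`.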
    pub fn gemm(
        &self,
        alpha: F,
        a: &Array2<F>,
        b: &Array2<F>,
        beta: F,
        c: &mut Array2<F>,
    ) -> Result<()> {
        let (m, k1) = a.dim();
        let (k2, n) = b.dim();
        let (cm, cn) = c.dim();

        if k1 != k2 {
            return Err(TimeSeriesError::DimensionMismatch {
                expected: k1,
                actual: k2,
            });
        }

        if cm != m || cn != n {
            return Err(TimeSeriesError::DimensionMismatch {
                expected: m * n,
                actual: cm * cn,
            });
        }

        let k = k1;
        // Square tiles with an edge of roughly sqrt(batch_size); clamp to at
        // least 1 so `step_by` never receives zero.
        let tile_size = ((self.config.batch_size as f64).sqrt() as usize).max(1);

        for i_tile in (0..m).step_by(tile_size) {
            for j_tile in (0..n).step_by(tile_size) {
                let i_end = (i_tile + tile_size).min(m);
                let j_end = (j_tile + tile_size).min(n);

                for i in i_tile..i_end {
                    for j in j_tile..j_end {
                        let mut sum = F::zero();

                        for k_idx in 0..k {
                            sum = sum + a[[i, k_idx]] * b[[k_idx, j]];
                        }

                        c[[i, j]] = alpha * sum + beta * c[[i, j]];
                    }
                }
            }
        }

        Ok(())
    }

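    /// Returns the transpose of `a`, written out tile by tile.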
    pub fn transpose(&self, a: &Array2<F>) -> Array2<F> {
        let (m, n) = a.dim();
        let mut result = Array2::zeros((n, m));

        // Clamp to at least 1 so `step_by` never receives zero.
        let tile_size = ((self.config.batch_size as f64).sqrt() as usize).max(1);

        for i_tile in (0..m).step_by(tile_size) {
            for j_tile in (0..n).step_by(tile_size) {
                let i_end = (i_tile + tile_size).min(m);
                let j_end = (j_tile + tile_size).min(n);

                for i in i_tile..i_end {
                    for j in j_tile..j_end {
                        result[[j, i]] = a[[i, j]];
                    }
                }
            }
        }

        result
    }

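    /// Applies `gemm` to each `(A, B, C)` triple in the batch; all three
    /// slices must have the same length.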
    pub fn batch_gemm(
        &self,
        alpha: F,
        a_batch: &[Array2<F>],
        b_batch: &[Array2<F>],
        beta: F,
        c_batch: &mut [Array2<F>],
    ) -> Result<()> {
        if a_batch.len() != b_batch.len() || b_batch.len() != c_batch.len() {
            return Err(TimeSeriesError::InvalidInput(
                "Batch sizes must match".to_string(),
            ));
        }

        for ((a, b), c) in a_batch.iter().zip(b_batch.iter()).zip(c_batch.iter_mut()) {
            self.gemm(alpha, a, b, beta, c)?;
        }

        Ok(())
    }
}

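/// BLAS operations specialized for tensor-core capable devices, falling back
/// to the plain `GpuBLAS` path for matrices below the configured minimum size.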
#[derive(Debug)]
pub struct TensorCoresBLAS<F: Float + Debug> {
    base_blas: GpuBLAS<F>,
    tensor_config: TensorCoresConfig,
    device_capabilities: GpuCapabilities,
}

impl<F: Float + Debug + Clone + scirs2_core::numeric::Zero + scirs2_core::numeric::One>
    TensorCoresBLAS<F>
{
    pub fn new(config: GpuConfig, device_capabilities: GpuCapabilities) -> Result<Self> {
        if !device_capabilities.supports_tensor_cores {
            return Err(TimeSeriesError::NotImplemented(
                "Device does not support tensor cores".to_string(),
            ));
        }

        let base_blas = GpuBLAS::new(config.clone());

        Ok(Self {
            base_blas,
            tensor_config: config.tensor_cores,
            device_capabilities,
        })
    }

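    /// Tensor-core style GEMM `C = alpha * A * B + beta * C`. Small matrices
    /// are delegated to the scalar `gemm` fallback; larger ones are processed
    /// in tiles chosen by `get_optimal_tile_size`.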
    pub fn tensor_gemm(
        &self,
        alpha: F,
        a: &Array2<F>,
        b: &Array2<F>,
        beta: F,
        c: &mut Array2<F>,
    ) -> Result<()> {
        let (m, k1) = a.dim();
        let (k2, n) = b.dim();

        if k1 != k2 {
            return Err(TimeSeriesError::DimensionMismatch {
                expected: k1,
                actual: k2,
            });
        }

        let k = k1;

        if m < self.tensor_config.min_matrix_size
            || n < self.tensor_config.min_matrix_size
            || k < self.tensor_config.min_matrix_size
        {
            return self.base_blas.gemm(alpha, a, b, beta, c);
        }

        let (tile_m, tile_n, tile_k) = self.get_optimal_tile_size(m, n, k);

        for i_tile in (0..m).step_by(tile_m) {
            for j_tile in (0..n).step_by(tile_n) {
                for k_tile in (0..k).step_by(tile_k) {
                    let i_end = (i_tile + tile_m).min(m);
                    let j_end = (j_tile + tile_n).min(n);
                    let k_end = (k_tile + tile_k).min(k);

                    // Apply `beta` only on the first k-tile; subsequent k-tiles
                    // accumulate partial products into C and must use a beta of 1.
                    let tile_beta = if k_tile == 0 { beta } else { F::one() };

                    self.process_tensor_tile(
                        alpha,
                        a,
                        b,
                        tile_beta,
                        c,
                        (i_tile, i_end),
                        (j_tile, j_end),
                        (k_tile, k_end),
                    )?;
                }
            }
        }

        Ok(())
    }

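    /// Picks tile dimensions from the generation's supported matrix shapes,
    /// scaled up toward the problem size; falls back to 32x32x32 when the
    /// tensor-core generation is unknown.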
    fn get_optimal_tile_size(&self, m: usize, n: usize, k: usize) -> (usize, usize, usize) {
        if let Some(generation) = self.device_capabilities.tensor_cores_generation {
            let supported_dims = generation.supported_matrix_dimensions();

            for &(tile_m, tile_n, tile_k) in &supported_dims {
                if tile_m <= m && tile_n <= n && tile_k <= k {
                    let scale_factor = ((m / tile_m).min(n / tile_n).min(k / tile_k)).max(1);
                    return (
                        tile_m * scale_factor,
                        tile_n * scale_factor,
                        tile_k * scale_factor,
                    );
                }
            }

            supported_dims[0]
        } else {
            (32, 32, 32)
        }
    }

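    /// Computes one (i, j, k) tile of the GEMM, accumulating the k dimension
    /// in fixed-size chunks before folding the result into `C`.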
    fn process_tensor_tile(
        &self,
        alpha: F,
        a: &Array2<F>,
        b: &Array2<F>,
        beta: F,
        c: &mut Array2<F>,
        (i_start, i_end): (usize, usize),
        (j_start, j_end): (usize, usize),
        (k_start, k_end): (usize, usize),
    ) -> Result<()> {
        for i in i_start..i_end {
            for j in j_start..j_end {
                let mut sum = F::zero();

                // Accumulate the k dimension in fixed chunks of four elements.
                let chunk_size = 4;
                let chunks = (k_end - k_start) / chunk_size;

                for chunk in 0..chunks {
                    let mut chunk_sum = F::zero();
                    let base_k = k_start + chunk * chunk_size;

                    for offset in 0..chunk_size {
                        let k_idx = base_k + offset;
                        if k_idx < k_end && k_idx < a.ncols() && k_idx < b.nrows() {
                            chunk_sum = chunk_sum + a[[i, k_idx]] * b[[k_idx, j]];
                        }
                    }
                    sum = sum + chunk_sum;
                }

                // Handle the remaining elements that did not fill a full chunk.
                for k_idx in (k_start + chunks * chunk_size)..k_end {
                    if k_idx < a.ncols() && k_idx < b.nrows() {
                        sum = sum + a[[i, k_idx]] * b[[k_idx, j]];
                    }
                }

                // Mixed-precision accumulation is not modelled separately here;
                // both configurations fold the sum into C at full precision.
                c[[i, j]] = alpha * sum + beta * c[[i, j]];
            }
        }

        Ok(())
    }

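    /// GEMM with loss scaling: `alpha` and `beta` are scaled by `loss_scale`
    /// so the accumulation runs at a larger magnitude, and the result is
    /// unscaled afterwards to recover the original values.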
    pub fn mixed_precision_gemm(
        &self,
        alpha: F,
        a: &Array2<F>,
        b: &Array2<F>,
        beta: F,
        c: &mut Array2<F>,
    ) -> Result<()> {
        if !self.tensor_config.mixed_precision {
            return self.tensor_gemm(alpha, a, b, beta, c);
        }

        // Scale both alpha and beta so the whole accumulation runs at
        // `loss_scale` times its true magnitude; scaling only alpha would
        // wrongly shrink the beta * C term when unscaling below.
        let loss_scale = F::from(self.tensor_config.loss_scale).unwrap();
        let scaled_alpha = alpha * loss_scale;
        let scaled_beta = beta * loss_scale;

        self.tensor_gemm(scaled_alpha, a, b, scaled_beta, c)?;

        let unscale_factor = F::one() / loss_scale;
        for elem in c.iter_mut() {
            *elem = *elem * unscale_factor;
        }

        Ok(())
    }

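    /// Applies `tensor_gemm` to each `(A, B, C)` triple in the batch.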
    pub fn batch_tensor_gemm(
        &self,
        alpha: F,
        a_batch: &[Array2<F>],
        b_batch: &[Array2<F>],
        beta: F,
        c_batch: &mut [Array2<F>],
    ) -> Result<()> {
        if a_batch.len() != b_batch.len() || b_batch.len() != c_batch.len() {
            return Err(TimeSeriesError::InvalidInput(
                "Batch sizes must match".to_string(),
            ));
        }

        for ((a, b), c) in a_batch.iter().zip(b_batch.iter()).zip(c_batch.iter_mut()) {
            self.tensor_gemm(alpha, a, b, beta, c)?;
        }

        Ok(())
    }

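    /// 2D convolution expressed as a GEMM: the input is unrolled with
    /// `im2col_transform`, multiplied against the flattened kernel, and the
    /// result is reshaped to the output dimensions.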
    pub fn tensor_convolution_gemm(
        &self,
        input: &Array2<F>,
        kernel: &Array2<F>,
        stride: usize,
    ) -> Result<Array2<F>> {
        let (input_height, input_width) = input.dim();
        let (kernel_height, kernel_width) = kernel.dim();

        if stride == 0 || kernel_height > input_height || kernel_width > input_width {
            return Err(TimeSeriesError::InvalidInput(
                "Stride must be non-zero and the kernel must fit within the input".to_string(),
            ));
        }

        let output_height = (input_height - kernel_height) / stride + 1;
        let output_width = (input_width - kernel_width) / stride + 1;

        // Unroll the input into columns so the convolution becomes a GEMM.
        let col_matrix = self.im2col_transform(input, kernel_height, kernel_width, stride)?;
        let kernel_view = kernel.view();
        let kernel_matrix = kernel_view
            .to_shape((1, kernel_height * kernel_width))
            .unwrap();

        let mut output_matrix = Array2::zeros((1, output_height * output_width));

        self.tensor_gemm(
            F::one(),
            &kernel_matrix.to_owned(),
            &col_matrix,
            F::zero(),
            &mut output_matrix,
        )?;

        Ok(output_matrix
            .to_shape((output_height, output_width))
            .unwrap()
            .to_owned())
    }

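    /// im2col: copies each kernel-sized window of the input into one column of
    /// a `(kernel_h * kernel_w) x (out_h * out_w)` matrix.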
    fn im2col_transform(
        &self,
        input: &Array2<F>,
        kernel_height: usize,
        kernel_width: usize,
        stride: usize,
    ) -> Result<Array2<F>> {
        let (input_height, input_width) = input.dim();
        let output_height = (input_height - kernel_height) / stride + 1;
        let output_width = (input_width - kernel_width) / stride + 1;

        let mut col_matrix =
            Array2::zeros((kernel_height * kernel_width, output_height * output_width));

        let mut col_idx = 0;
        for out_y in 0..output_height {
            for out_x in 0..output_width {
                let mut row_idx = 0;
                for ky in 0..kernel_height {
                    for kx in 0..kernel_width {
                        let input_y = out_y * stride + ky;
                        let input_x = out_x * stride + kx;

                        if input_y < input_height && input_x < input_width {
                            col_matrix[[row_idx, col_idx]] = input[[input_y, input_x]];
                        }
                        row_idx += 1;
                    }
                }
                col_idx += 1;
            }
        }

        Ok(col_matrix)
    }

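    /// Returns `true` when tensor cores are enabled, supported by the device,
    /// the matrices meet the minimum size, and the dimensions are multiples of
    /// a supported tile shape.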
    pub fn can_use_tensor_cores(&self, m: usize, n: usize, k: usize) -> bool {
        if !self.tensor_config.enabled || !self.device_capabilities.supports_tensor_cores {
            return false;
        }

        if m < self.tensor_config.min_matrix_size
            || n < self.tensor_config.min_matrix_size
            || k < self.tensor_config.min_matrix_size
        {
            return false;
        }

        if let Some(generation) = self.device_capabilities.tensor_cores_generation {
            let supported_dims = generation.supported_matrix_dimensions();
            for &(tile_m, tile_n, tile_k) in &supported_dims {
                if m.is_multiple_of(tile_m) && n.is_multiple_of(tile_n) && k.is_multiple_of(tile_k)
                {
                    return true;
                }
            }
        }

        false
    }

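    /// Estimates the GEMM runtime in seconds from the device's peak tensor
    /// throughput (TOPS) and an alignment-based efficiency factor; returns
    /// `None` when tensor cores cannot be used for this problem size.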
    pub fn estimate_tensor_performance(&self, m: usize, n: usize, k: usize) -> Option<f64> {
        if !self.can_use_tensor_cores(m, n, k) {
            return None;
        }

        if let Some(peak_tops) = self.device_capabilities.tensor_performance {
            // A GEMM performs 2 * m * n * k floating-point operations.
            let total_ops = 2.0 * m as f64 * n as f64 * k as f64;
            let efficiency = self.estimate_efficiency(m, n, k);
            let estimated_tops = peak_tops * efficiency;

            // Convert TOPS (10^12 ops/s) to ops/s and return the runtime in seconds.
            Some(total_ops / (estimated_tops * 1e12))
        } else {
            None
        }
    }

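    /// Heuristic fraction of peak throughput: a per-generation base rate
    /// discounted by how far each dimension is from a multiple of the optimal
    /// tile size.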
    fn estimate_efficiency(&self, m: usize, n: usize, k: usize) -> f64 {
        if let Some(generation) = self.device_capabilities.tensor_cores_generation {
            let (opt_m, opt_n, opt_k) = self.get_optimal_tile_size(m, n, k);

            // Fraction of each dimension left over after tiling; perfectly
            // aligned dimensions contribute a factor of 1.0 below.
            let m_remainder = (m % opt_m) as f64 / opt_m as f64;
            let n_remainder = (n % opt_n) as f64 / opt_n as f64;
            let k_remainder = (k % opt_k) as f64 / opt_k as f64;

            let alignment_efficiency =
                (1.0 - m_remainder) * (1.0 - n_remainder) * (1.0 - k_remainder);

            // Newer tensor-core generations achieve a higher fraction of peak.
            let base_efficiency = match generation {
                TensorCoresGeneration::V1 => 0.7,
                TensorCoresGeneration::V2 => 0.8,
                TensorCoresGeneration::V3 => 0.9,
                TensorCoresGeneration::V4 => 0.95,
            };

            base_efficiency * alignment_efficiency.max(0.5)
        } else {
            // Unknown generation: assume a conservative 50% of peak.
            0.5
        }
    }
}