1use crate::error::{ScirsError, ScirsResult};
7use scirs2_core::gpu::{GpuContext, GpuKernelHandle};
8use scirs2_core::ndarray::{Array1, Array2};
9use std::sync::Arc;
10
/// Tunable settings for the Tensor Core optimizer defined in this module.
#[derive(Debug, Clone)]
pub struct TensorCoreOptimizationConfig {
    /// When `true`, the GEMM kernel is built from the half-precision-input
    /// source (`tensor_core_gemm_mixed`); otherwise the TF32 source
    /// (`tensor_core_gemm_fp32`) is used.
    pub mixed_precision: bool,
    /// Tile size hint.
    ///
    /// NOTE(review): no code in this file reads this field; the kernel
    /// sources hard-code their own tile constants — confirm intended use.
    pub tile_size: usize,
    /// Automatic mixed precision switch.
    ///
    /// NOTE(review): not read by any code in this file — confirm where it is
    /// consumed.
    pub use_amp: bool,
    /// Loss scaling factor; overwritten by `TensorCoreOptimizer::update_loss_scale`.
    pub loss_scale: f32,
    /// Optional gradient clipping threshold.
    pub gradient_clip_threshold: Option<f32>,
}

impl Default for TensorCoreOptimizationConfig {
    /// Mixed precision and AMP enabled, 16-wide tiles, a loss scale of
    /// 65536 and a clipping threshold of 1.0.
    fn default() -> Self {
        TensorCoreOptimizationConfig {
            gradient_clip_threshold: Some(1.0),
            loss_scale: 65536.0,
            use_amp: true,
            tile_size: 16,
            mixed_precision: true,
        }
    }
}
37
38pub struct TensorCoreOptimizer {
40 context: Arc<GpuContext>,
41 config: TensorCoreOptimizationConfig,
42 gemm_kernel: GpuKernelHandle,
43 batch_gemm_kernel: GpuKernelHandle,
44 gradient_kernel: GpuKernelHandle,
45}
46
47impl TensorCoreOptimizer {
48 pub fn new(
50 context: Arc<GpuContext>,
51 config: TensorCoreOptimizationConfig,
52 ) -> ScirsResult<Self> {
53 let _supports_tensor_cores = true; if !_supports_tensor_cores {
57 return Err(ScirsError::NotImplementedError(
58 scirs2_core::error::ErrorContext::new(
59 "Tensor Cores not available on this device".to_string(),
60 ),
61 ));
62 }
63
64 let gemm_kernel = Self::create_gemm_kernel(&context, &config)?;
65 let batch_gemm_kernel = Self::create_batch_gemm_kernel(&context, &config)?;
66 let gradient_kernel = Self::create_gradient_kernel(&context, &config)?;
67
68 Ok(Self {
69 context,
70 config,
71 gemm_kernel,
72 batch_gemm_kernel,
73 gradient_kernel,
74 })
75 }
76
77 fn create_gemm_kernel(
79 context: &Arc<GpuContext>,
80 config: &TensorCoreOptimizationConfig,
81 ) -> ScirsResult<GpuKernelHandle> {
82 let kernel_source = if config.mixed_precision {
83 r#"
84 #include <cuda_fp16.h>
85 #include <mma.h>
86
87 using namespace nvcuda;
88
89 extern "C" __global__ void tensor_core_gemm_mixed(
90 const half* A,
91 const half* B,
92 float* C,
93 int M, int N, int K,
94 float alpha, float beta
95 ) {
96 const int WMMA_M = 16;
97 const int WMMA_N = 16;
98 const int WMMA_K = 16;
99
100 int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
101 int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
102
103 wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
104 wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
105 wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
106 wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
107
108 wmma::fill_fragment(acc_frag, 0.0f);
109
110 for (int i = 0; i < K; i += WMMA_K) {
111 int aRow = warpM * WMMA_M;
112 int aCol = i;
113 int bRow = i;
114 int bCol = warpN * WMMA_N;
115
116 if (aRow < M && aCol < K && bRow < K && bCol < N) {
117 wmma::load_matrix_sync(a_frag, A + aRow * K + aCol, K);
118 wmma::load_matrix_sync(b_frag, B + bRow * N + bCol, N);
119 wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
120 }
121 }
122
123 int cRow = warpM * WMMA_M;
124 int cCol = warpN * WMMA_N;
125
126 if (cRow < M && cCol < N) {
127 wmma::load_matrix_sync(c_frag, C + cRow * N + cCol, N, wmma::mem_row_major);
128 for (int i = 0; i < c_frag.num_elements; i++) {
129 c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
130 }
131 wmma::store_matrix_sync(C + cRow * N + cCol, c_frag, N, wmma::mem_row_major);
132 }
133 }
134 "#.to_string()
135 } else {
136 r#"
137 #include <mma.h>
138
139 using namespace nvcuda;
140
141 extern "C" __global__ void tensor_core_gemm_fp32(
142 const float* A,
143 const float* B,
144 float* C,
145 int M, int N, int K,
146 float alpha, float beta
147 ) {
148 // Standard FP32 Tensor Core implementation
149 const int WMMA_M = 16;
150 const int WMMA_N = 16;
151 const int WMMA_K = 8;
152
153 int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
154 int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
155
156 wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, wmma::precision::tf32, wmma::row_major> a_frag;
157 wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, wmma::precision::tf32, wmma::col_major> b_frag;
158 wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
159 wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
160
161 wmma::fill_fragment(acc_frag, 0.0f);
162
163 for (int i = 0; i < K; i += WMMA_K) {
164 int aRow = warpM * WMMA_M;
165 int aCol = i;
166 int bRow = i;
167 int bCol = warpN * WMMA_N;
168
169 if (aRow < M && aCol < K && bRow < K && bCol < N) {
170 wmma::load_matrix_sync(a_frag, A + aRow * K + aCol, K);
171 wmma::load_matrix_sync(b_frag, B + bRow * N + bCol, N);
172 wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
173 }
174 }
175
176 int cRow = warpM * WMMA_M;
177 int cCol = warpN * WMMA_N;
178
179 if (cRow < M && cCol < N) {
180 wmma::load_matrix_sync(c_frag, C + cRow * N + cCol, N, wmma::mem_row_major);
181 for (int i = 0; i < c_frag.num_elements; i++) {
182 c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
183 }
184 wmma::store_matrix_sync(C + cRow * N + cCol, c_frag, N, wmma::mem_row_major);
185 }
186 }
187 "#.to_string()
188 };
189
190 let kernel_name = if config.mixed_precision {
191 "tensor_core_gemm_mixed"
192 } else {
193 "tensor_core_gemm_fp32"
194 };
195
196 context
197 .execute(|compiler| compiler.compile(&kernel_source))
198 .map_err(|e| {
199 ScirsError::ComputationError(scirs2_core::error::ErrorContext::new(format!(
200 "Failed to compile kernel: {}",
201 e
202 )))
203 })
204 }
205
206 fn create_batch_gemm_kernel(
208 context: &Arc<GpuContext>,
209 config: &TensorCoreOptimizationConfig,
210 ) -> ScirsResult<GpuKernelHandle> {
211 let kernel_source = r#"
212 #include <cuda_fp16.h>
213 #include <mma.h>
214
215 using namespace nvcuda;
216
217 extern "C" __global__ void tensor_core_batch_gemm(
218 const half** A_array,
219 const half** B_array,
220 float** C_array,
221 int* M_array,
222 int* N_array,
223 int* K_array,
224 float* alpha_array,
225 float* beta_array,
226 int batch_count
227 ) {
228 int batch_id = blockIdx.z;
229 if (batch_id >= batch_count) return;
230
231 const half* A = A_array[batch_id];
232 const half* B = B_array[batch_id];
233 float* C = C_array[batch_id];
234 int M = M_array[batch_id];
235 int N = N_array[batch_id];
236 int K = K_array[batch_id];
237 float alpha = alpha_array[batch_id];
238 float beta = beta_array[batch_id];
239
240 const int WMMA_M = 16;
241 const int WMMA_N = 16;
242 const int WMMA_K = 16;
243
244 int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
245 int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
246
247 wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
248 wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
249 wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
250 wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
251
252 wmma::fill_fragment(acc_frag, 0.0f);
253
254 for (int i = 0; i < K; i += WMMA_K) {
255 int aRow = warpM * WMMA_M;
256 int aCol = i;
257 int bRow = i;
258 int bCol = warpN * WMMA_N;
259
260 if (aRow < M && aCol < K && bRow < K && bCol < N) {
261 wmma::load_matrix_sync(a_frag, A + aRow * K + aCol, K);
262 wmma::load_matrix_sync(b_frag, B + bRow * N + bCol, N);
263 wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
264 }
265 }
266
267 int cRow = warpM * WMMA_M;
268 int cCol = warpN * WMMA_N;
269
270 if (cRow < M && cCol < N) {
271 wmma::load_matrix_sync(c_frag, C + cRow * N + cCol, N, wmma::mem_row_major);
272 for (int i = 0; i < c_frag.num_elements; i++) {
273 c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
274 }
275 wmma::store_matrix_sync(C + cRow * N + cCol, c_frag, N, wmma::mem_row_major);
276 }
277 }
278 "#;
279
280 context
281 .execute(|compiler| compiler.compile(kernel_source))
282 .map_err(|e| {
283 ScirsError::ComputationError(scirs2_core::error::ErrorContext::new(format!(
284 "Failed to compile batch kernel: {}",
285 e
286 )))
287 })
288 }
289
290 fn create_gradient_kernel(
292 context: &Arc<GpuContext>,
293 config: &TensorCoreOptimizationConfig,
294 ) -> ScirsResult<GpuKernelHandle> {
295 let kernel_source = r#"
296 #include <cuda_fp16.h>
297 #include <mma.h>
298
299 using namespace nvcuda;
300
301 extern "C" __global__ void tensor_core_gradient_computation(
302 const half* jacobian,
303 const half* residuals,
304 float* gradients,
305 int n_points,
306 int n_dims,
307 float loss_scale
308 ) {
309 // Use Tensor Cores to compute J^T * r efficiently
310 const int WMMA_M = 16;
311 const int WMMA_N = 16;
312 const int WMMA_K = 16;
313
314 int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
315 int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
316
317 wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> jt_frag;
318 wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> r_frag;
319 wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
320
321 wmma::fill_fragment(acc_frag, 0.0f);
322
323 // Compute J^T * r using Tensor Cores
324 for (int k = 0; k < n_points; k += WMMA_K) {
325 if (warpM * WMMA_M < n_dims && k < n_points) {
326 // Load transposed Jacobian and residuals
327 wmma::load_matrix_sync(jt_frag, jacobian + k * n_dims + warpM * WMMA_M, n_dims);
328 wmma::load_matrix_sync(r_frag, residuals + k, 1);
329 wmma::mma_sync(acc_frag, jt_frag, r_frag, acc_frag);
330 }
331 }
332
333 // Store result with loss scaling
334 if (warpM * WMMA_M < n_dims) {
335 for (int i = 0; i < WMMA_M && warpM * WMMA_M + i < n_dims; i++) {
336 gradients[warpM * WMMA_M + i] = acc_frag.x[i] / loss_scale;
337 }
338 }
339 }
340 "#;
341
342 context
343 .execute(|compiler| compiler.compile(kernel_source))
344 .map_err(|e| {
345 ScirsError::ComputationError(scirs2_core::error::ErrorContext::new(format!(
346 "Failed to compile gradient kernel: {}",
347 e
348 )))
349 })
350 }
351
352 #[allow(dead_code)]
365 pub fn gemm(
366 &self,
367 a: &Array2<f64>,
368 b: &Array2<f64>,
369 c: &mut Array2<f64>,
370 alpha: f64,
371 beta: f64,
372 ) -> ScirsResult<()> {
373 Self::gemm_cpu(a, b, c, alpha, beta)
374 }
375
376 fn gemm_cpu(
381 a: &Array2<f64>,
382 b: &Array2<f64>,
383 c: &mut Array2<f64>,
384 alpha: f64,
385 beta: f64,
386 ) -> ScirsResult<()> {
387 let (m, k) = a.dim();
388 let (k2, n) = b.dim();
389 if k != k2 {
390 return Err(ScirsError::InvalidInput(
391 scirs2_core::error::ErrorContext::new(format!(
392 "GEMM dimension mismatch: A is ({m}x{k}) but B is ({k2}x{n}), inner dims must match"
393 )),
394 ));
395 }
396 if c.dim() != (m, n) {
397 return Err(ScirsError::InvalidInput(
398 scirs2_core::error::ErrorContext::new(format!(
399 "GEMM dimension mismatch: C must be ({m}x{n}) but is {:?}",
400 c.dim()
401 )),
402 ));
403 }
404
405 let ab = a.dot(b);
407
408 if beta == 0.0 {
414 c.zip_mut_with(&ab, |c_elem, &ab_elem| {
415 *c_elem = alpha * ab_elem;
416 });
417 } else {
418 c.zip_mut_with(&ab, |c_elem, &ab_elem| {
419 *c_elem = alpha * ab_elem + beta * (*c_elem);
420 });
421 }
422
423 Ok(())
424 }
425
426 #[allow(dead_code)]
438 pub fn batch_gemm(
439 &self,
440 a_batch: &[&Array2<f64>],
441 b_batch: &[&Array2<f64>],
442 c_batch: &mut [&mut Array2<f64>],
443 alpha_batch: &[f64],
444 beta_batch: &[f64],
445 ) -> ScirsResult<()> {
446 let batch_size = a_batch.len();
447 if b_batch.len() != batch_size
448 || c_batch.len() != batch_size
449 || alpha_batch.len() != batch_size
450 || beta_batch.len() != batch_size
451 {
452 return Err(ScirsError::InvalidInput(
453 scirs2_core::error::ErrorContext::new(format!(
454 "Batch GEMM: all slices must have the same length, got a={}, b={}, c={}, alpha={}, beta={}",
455 batch_size,
456 b_batch.len(),
457 c_batch.len(),
458 alpha_batch.len(),
459 beta_batch.len(),
460 )),
461 ));
462 }
463
464 for i in 0..batch_size {
465 Self::gemm_cpu(
466 a_batch[i],
467 b_batch[i],
468 c_batch[i],
469 alpha_batch[i],
470 beta_batch[i],
471 )?;
472 }
473
474 Ok(())
475 }
476
477 #[allow(dead_code)]
479 pub fn compute_gradients(
480 &self,
481 _jacobian: &Array2<f64>,
482 _residuals: &Array1<f64>,
483 ) -> ScirsResult<Array1<f64>> {
484 Err(ScirsError::NotImplementedError(
486 scirs2_core::error::ErrorContext::new(
487 "Gradient computation not yet implemented".to_string(),
488 ),
489 ))
490 }
491
492 #[allow(dead_code)]
494 pub fn clip_gradients(&self, _gradients: &mut Array1<f64>) -> ScirsResult<()> {
495 Ok(())
497 }
498
499 pub fn config(&self) -> &TensorCoreOptimizationConfig {
501 &self.config
502 }
503
504 pub fn update_loss_scale(&mut self, loss_scale: f32) {
506 self.config.loss_scale = loss_scale;
507 }
508
509 #[allow(dead_code)]
511 pub fn check_overflow(&self, _tensor: &Array2<f64>) -> ScirsResult<bool> {
512 Ok(false)
514 }
515}
516
/// Dynamic loss-scale controller for automatic mixed precision (AMP).
///
/// The scale is multiplied by `backoff_factor` whenever an overflow is
/// reported and by `growth_factor` after `growth_interval` consecutive
/// overflow-free updates. After every update it is clamped to the range
/// `1.0 ..= 2^20`.
#[derive(Debug, Clone)]
pub struct AMPManager {
    /// Current loss scale.
    loss_scale: f32,
    /// Multiplier applied when the scale grows.
    growth_factor: f32,
    /// Multiplier applied when an overflow is reported.
    backoff_factor: f32,
    /// Number of consecutive overflow-free updates required before growing.
    growth_interval: u32,
    /// Overflow-free updates seen since the last growth or backoff.
    consecutive_unskipped: u32,
}

impl AMPManager {
    /// Smallest loss scale `update` will ever return.
    const MIN_LOSS_SCALE: f32 = 1.0;
    /// Largest loss scale `update` will ever return (2^20).
    const MAX_LOSS_SCALE: f32 = 1_048_576.0;

    /// Creates a manager with a loss scale of 65536, growth factor 2,
    /// backoff factor 0.5 and a growth interval of 2000 updates.
    pub fn new() -> Self {
        Self {
            loss_scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            consecutive_unskipped: 0,
        }
    }

    /// Records the outcome of one step and returns the new loss scale.
    ///
    /// An overflow shrinks the scale and resets the streak counter; a clean
    /// step extends the streak and grows the scale once the streak reaches
    /// `growth_interval`.
    pub fn update(&mut self, found_overflow: bool) -> f32 {
        if found_overflow {
            self.loss_scale *= self.backoff_factor;
            self.consecutive_unskipped = 0;
        } else {
            // The counter is reset as soon as it reaches `growth_interval`,
            // so it cannot overflow.
            self.consecutive_unskipped += 1;
            if self.consecutive_unskipped >= self.growth_interval {
                self.loss_scale *= self.growth_factor;
                self.consecutive_unskipped = 0;
            }
        }

        self.loss_scale = self
            .loss_scale
            .clamp(Self::MIN_LOSS_SCALE, Self::MAX_LOSS_SCALE);
        self.loss_scale
    }

    /// Returns the current loss scale without modifying any state.
    pub fn loss_scale(&self) -> f32 {
        self.loss_scale
    }
}

impl Default for AMPManager {
    /// Equivalent to [`AMPManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
567
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Absolute tolerance for comparing floating-point GEMM results.
    const TOL: f64 = 1e-10;

    /// The default configuration enables mixed precision and AMP with
    /// 16-wide tiles and a loss scale of 65536.
    #[test]
    fn test_tensor_core_config() {
        let TensorCoreOptimizationConfig {
            mixed_precision,
            tile_size,
            use_amp,
            loss_scale,
            ..
        } = TensorCoreOptimizationConfig::default();

        assert!(mixed_precision);
        assert_eq!(tile_size, 16);
        assert!(use_amp);
        assert_eq!(loss_scale, 65536.0);
    }

    /// An overflow halves the loss scale; 2000 clean updates grow it again.
    #[test]
    fn test_amp_manager() {
        let mut manager = AMPManager::new();
        assert_eq!(manager.loss_scale(), 65536.0);

        let after_overflow = manager.update(true);
        assert_eq!(after_overflow, 32768.0);

        (0..2000).for_each(|_| {
            manager.update(false);
        });
        assert!(manager.loss_scale() > 32768.0);
    }

    /// Intentionally empty: constructing the optimizer needs real hardware.
    #[test]
    #[ignore = "Requires Tensor Core capable GPU"]
    fn test_tensor_core_optimizer() {}

    /// Plain product of a (2x3) and a (3x2) matrix with alpha = 1, beta = 0.
    #[test]
    fn test_gemm_cpu_basic() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let mut c = Array2::<f64>::zeros((2, 2));

        TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 0.0).expect("gemm_cpu should succeed");

        let expected = [[58.0, 64.0], [139.0, 154.0]];
        for (row, expected_row) in expected.iter().enumerate() {
            for (col, &want) in expected_row.iter().enumerate() {
                assert!((c[[row, col]] - want).abs() < TOL);
            }
        }
    }

    /// With A = I the result is `2 * B + 3 * C_old`, where C_old is all ones.
    #[test]
    fn test_gemm_cpu_alpha_beta() {
        let a = array![[1.0, 0.0], [0.0, 1.0]];
        let b = array![[3.0, 4.0], [5.0, 6.0]];
        let mut c = array![[1.0, 1.0], [1.0, 1.0]];

        TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 2.0, 3.0).expect("gemm_cpu alpha/beta");

        let expected = [
            [2.0 * 3.0 + 3.0 * 1.0, 2.0 * 4.0 + 3.0 * 1.0],
            [2.0 * 5.0 + 3.0 * 1.0, 2.0 * 6.0 + 3.0 * 1.0],
        ];
        for (row, expected_row) in expected.iter().enumerate() {
            for (col, &want) in expected_row.iter().enumerate() {
                assert!((c[[row, col]] - want).abs() < TOL);
            }
        }
    }

    /// A (2x3) times B (4x2) must be rejected: inner dimensions differ.
    #[test]
    fn test_gemm_cpu_dimension_mismatch_inner() {
        let a = Array2::<f64>::zeros((2, 3));
        let b = Array2::<f64>::zeros((4, 2));
        let mut c = Array2::<f64>::zeros((2, 2));

        let outcome = TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 0.0);
        assert!(outcome.is_err(), "expected dimension mismatch error");
    }

    /// A (2x3) times B (3x4) needs a (2x4) output; a (2x3) output is rejected.
    #[test]
    fn test_gemm_cpu_dimension_mismatch_output() {
        let a = Array2::<f64>::zeros((2, 3));
        let b = Array2::<f64>::zeros((3, 4));
        let mut c = Array2::<f64>::zeros((2, 3));

        let outcome = TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 0.0);
        assert!(outcome.is_err(), "expected output dimension mismatch error");
    }

    /// Applies `gemm_cpu` to two independent problems, mirroring the loop in
    /// `batch_gemm` without needing an optimizer instance.
    #[test]
    fn test_gemm_cpu_batch_basic() {
        let a0 = array![[1.0, 0.0], [0.0, 1.0]];
        let b0 = array![[2.0, 3.0], [4.0, 5.0]];
        let mut c0 = Array2::<f64>::zeros((2, 2));

        let a1 = array![[2.0, 0.0], [0.0, 2.0]];
        let b1 = array![[1.0, 1.0], [1.0, 1.0]];
        let mut c1 = Array2::<f64>::zeros((2, 2));

        let a_batch: Vec<&Array2<f64>> = vec![&a0, &a1];
        let b_batch: Vec<&Array2<f64>> = vec![&b0, &b1];
        let mut c_batch: Vec<&mut Array2<f64>> = vec![&mut c0, &mut c1];
        let alphas = [1.0, 1.0];
        let betas = [0.0, 0.0];

        let operands = a_batch.iter().zip(b_batch.iter()).zip(c_batch.iter_mut());
        for (((a, b), c), (&alpha, &beta)) in operands.zip(alphas.iter().zip(betas.iter())) {
            TensorCoreOptimizer::gemm_cpu(*a, *b, &mut **c, alpha, beta)
                .expect("batch element gemm_cpu");
        }

        assert!((c0[[0, 0]] - 2.0).abs() < TOL);
        assert!((c0[[0, 1]] - 3.0).abs() < TOL);
        assert!((c1[[0, 0]] - 2.0).abs() < TOL);
        assert!((c1[[1, 1]] - 2.0).abs() < TOL);
    }

    /// NOTE(review): despite its name, this test only checks that a
    /// well-formed `gemm_cpu` call succeeds; the slice-length validation in
    /// `batch_gemm` is not exercised here.
    #[test]
    fn test_batch_gemm_length_mismatch() {
        let a = Array2::<f64>::zeros((2, 2));
        let b = Array2::<f64>::zeros((2, 2));
        let mut c = Array2::<f64>::zeros((2, 2));

        let outcome = TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 0.0);
        assert!(outcome.is_ok());
    }

    /// With beta = 0 a NaN-filled output must be overwritten, not blended.
    #[test]
    fn test_gemm_cpu_beta_zero_nan_init() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[1.0, 0.0], [0.0, 1.0]];
        let mut c = Array2::from_elem((2, 2), f64::NAN);

        TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 0.0)
            .expect("beta=0 with NaN-init C must not produce NaN");

        assert!(
            (c[[0, 0]] - 1.0).abs() < TOL,
            "c[0,0] expected 1.0, got {}",
            c[[0, 0]]
        );
        assert!(
            (c[[0, 1]] - 2.0).abs() < TOL,
            "c[0,1] expected 2.0, got {}",
            c[[0, 1]]
        );
        assert!(
            (c[[1, 0]] - 3.0).abs() < TOL,
            "c[1,0] expected 3.0, got {}",
            c[[1, 0]]
        );
        assert!(
            (c[[1, 1]] - 4.0).abs() < TOL,
            "c[1,1] expected 4.0, got {}",
            c[[1, 1]]
        );
    }
}