/// Crate version string, taken from Cargo.toml (`CARGO_PKG_VERSION`) at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Major version component.
/// NOTE(review): the three numeric components below are hard-coded and can
/// drift from `VERSION` above — confirm they are kept in sync with Cargo.toml.
pub const VERSION_MAJOR: u32 = 0;
/// Minor version component.
pub const VERSION_MINOR: u32 = 1;
/// Patch version component.
pub const VERSION_PATCH: u32 = 0;
14
15use torsh_core::{Result, TorshError};
16use torsh_tensor::Tensor;
17
/// Crate-local result alias over `torsh_core::Result`.
pub type TorshResult<T> = Result<T>;
20
21pub mod advanced_ops;
22pub mod comparison;
23pub mod decomposition;
24pub mod matrix_functions;
25pub mod numerical_stability;
26pub mod perf;
27pub mod randomized;
28pub mod solve;
29pub mod solvers;
30pub mod sparse;
31pub mod special_matrices;
32pub mod taylor;
33pub mod utils;
34
35#[cfg(feature = "scirs2-integration")]
37pub mod attention;
38#[cfg(feature = "scirs2-integration")]
39pub mod matrix_calculus;
40#[cfg(feature = "scirs2-integration")]
41pub mod matrix_equations;
42#[cfg(feature = "scirs2-integration")]
43pub mod quantization;
44
45#[cfg(feature = "scirs2-integration")]
47pub mod scirs2_linalg_integration;
48
49pub use advanced_ops::*;
51pub use comparison::*;
52pub use decomposition::*;
53pub use matrix_functions::*;
54pub use numerical_stability::{
56 check_numerical_stability, equilibrate_matrix, unequilibrate_solution, EquilibrationStrategy,
57 ScalingFactors, StabilityConfig,
58};
59pub use randomized::*;
60pub use solvers::*;
63pub use sparse::*;
64pub use special_matrices::*;
65pub use taylor::*;
66pub use utils::*;
67
68#[cfg(feature = "scirs2-integration")]
70pub use scirs2_linalg_integration::*;
71
72pub(crate) fn validate_square_matrix(tensor: &Tensor, operation: &str) -> TorshResult<usize> {
74 if tensor.shape().ndim() != 2 {
75 return Err(TorshError::InvalidArgument(format!(
76 "{} requires a 2D tensor, got {}D tensor",
77 operation,
78 tensor.shape().ndim()
79 )));
80 }
81
82 let (rows, cols) = (tensor.shape().dims()[0], tensor.shape().dims()[1]);
83 if rows != cols {
84 return Err(TorshError::InvalidArgument(format!(
85 "{operation} requires a square matrix, got {rows}x{cols} matrix"
86 )));
87 }
88
89 Ok(rows)
90}
91
/// Validate that `a` and `b` are both 2-D and dimensionally compatible for
/// `operation`.
///
/// # Errors
/// `InvalidArgument` when either tensor is not 2-D or the checked dimensions
/// disagree.
#[allow(dead_code)]
fn validate_matrix_dimensions(a: &Tensor, b: &Tensor, operation: &str) -> TorshResult<()> {
    if a.shape().ndim() != 2 || b.shape().ndim() != 2 {
        return Err(TorshError::InvalidArgument(format!(
            "{operation} requires 2D tensors, got {}D and {}D tensors",
            a.shape().ndim(),
            b.shape().ndim()
        )));
    }

    let a_shape = a.shape();
    let b_shape = b.shape();
    let a_dims = a_shape.dims();
    let b_dims = b_shape.dims();

    // NOTE(review): this compares the ROW counts of the two operands, which
    // fits solve-style usage (A: n x n, B: n x k). Multiplication-style
    // compatibility would compare a_dims[1] with b_dims[0] — confirm the
    // intended contract before reusing this (currently dead) helper.
    if a_dims[0] != b_dims[0] {
        return Err(TorshError::InvalidArgument(format!(
            "{operation} requires compatible dimensions, got {}x{} and {}x{}",
            a_dims[0], a_dims[1], b_dims[0], b_dims[1]
        )));
    }

    Ok(())
}
117
118fn vector_norm_2(tensor: &Tensor) -> TorshResult<f32> {
120 let n = tensor.shape().dims()[0];
121 let mut sum = 0.0f32;
122
123 for i in 0..n {
124 let val = tensor.get(&[i])?;
125 sum += val * val;
126 }
127
128 Ok(sum.sqrt())
129}
130
131fn vector_inner_product(a: &Tensor, b: &Tensor) -> TorshResult<f32> {
133 let n = a.shape().dims()[0];
134 let mut sum = 0.0f32;
135
136 for i in 0..n {
137 sum += a.get(&[i])? * b.get(&[i])?;
138 }
139
140 Ok(sum)
141}
142
143fn get_relative_tolerance(tensor: &Tensor, default_tol: f32) -> TorshResult<f32> {
145 let (m, n) = (tensor.shape().dims()[0], tensor.shape().dims()[1]);
147 let mut max_abs = 0.0f32;
148
149 for i in 0..m {
150 for j in 0..n {
151 let val = tensor.get(&[i, j])?.abs();
152 if val > max_abs {
153 max_abs = val;
154 }
155 }
156 }
157
158 Ok((max_abs * default_tol).max(1e-12))
160}
161
162pub fn det(tensor: &Tensor) -> TorshResult<f32> {
164 let n = validate_square_matrix(tensor, "Determinant computation")?;
165
166 match n {
168 1 => tensor.get(&[0, 0]),
169 2 => {
170 let a = tensor.get(&[0, 0])?;
171 let b = tensor.get(&[0, 1])?;
172 let c = tensor.get(&[1, 0])?;
173 let d = tensor.get(&[1, 1])?;
174 Ok(a * d - b * c)
175 }
176 _ => {
177 let (_, _, u) = lu(tensor)?;
179
180 let mut det = 1.0;
182 for i in 0..n {
183 det *= u.get(&[i, i])?;
184 }
185 Ok(det)
186 }
187 }
188}
189
190pub fn matrix_rank(tensor: &Tensor, tol: Option<f32>) -> TorshResult<usize> {
192 if tensor.shape().ndim() != 2 {
193 return Err(TorshError::InvalidArgument(format!(
194 "Matrix rank computation requires a 2D tensor, got {}D tensor",
195 tensor.shape().ndim()
196 )));
197 }
198
199 let (_, s, _) = svd(tensor, false)?;
201
202 let tol = if let Some(user_tol) = tol {
204 user_tol
205 } else {
206 let max_sv = s.get(&[0])?; (max_sv * 1e-6).max(1e-12) };
210
211 let s_len = s.shape().dims()[0];
213
214 let mut rank = 0;
216 for i in 0..s_len {
217 let singular_value = s.get(&[i])?;
218 if singular_value.abs() > tol {
219 rank += 1;
220 }
221 }
222
223 Ok(rank)
224}
225
226pub fn trace(tensor: &Tensor) -> TorshResult<f32> {
228 if tensor.shape().ndim() != 2 {
229 return Err(TorshError::InvalidArgument(format!(
230 "Trace computation requires a 2D tensor, got {}D tensor",
231 tensor.shape().ndim()
232 )));
233 }
234
235 let (rows, cols) = (tensor.shape().dims()[0], tensor.shape().dims()[1]);
236 let size = rows.min(cols);
237
238 let mut sum = 0.0;
239 for i in 0..size {
240 sum += tensor.get(&[i, i])?;
241 }
242
243 Ok(sum)
244}
245
246pub fn matmul(a: &Tensor, b: &Tensor) -> TorshResult<Tensor> {
248 if a.shape().ndim() < 2 || b.shape().ndim() < 2 {
249 return Err(TorshError::InvalidArgument(format!(
250 "Matrix multiplication requires at least 2D tensors, got {}D and {}D tensors",
251 a.shape().ndim(),
252 b.shape().ndim()
253 )));
254 }
255
256 let a_shape = a.shape();
258 let b_shape = b.shape();
259 let a_dims = a_shape.dims();
260 let b_dims = b_shape.dims();
261
262 let a_rows = a_dims[a_dims.len() - 2];
263 let a_cols = a_dims[a_dims.len() - 1];
264 let b_rows = b_dims[b_dims.len() - 2];
265 let b_cols = b_dims[b_dims.len() - 1];
266
267 if a_cols != b_rows {
268 return Err(TorshError::InvalidArgument(
269 format!("Incompatible dimensions for matrix multiplication: {a_rows}x{a_cols} and {b_rows}x{b_cols}")
270 ));
271 }
272
273 a.matmul(b)
276}
277
278pub fn matvec(matrix: &Tensor, vector: &Tensor) -> TorshResult<Tensor> {
280 if matrix.shape().ndim() != 2 {
281 return Err(TorshError::InvalidArgument(
282 "matvec requires 2D matrix".to_string(),
283 ));
284 }
285
286 if vector.shape().ndim() != 1 {
287 return Err(TorshError::InvalidArgument(
288 "matvec requires 1D vector".to_string(),
289 ));
290 }
291
292 let (m, n) = (matrix.shape().dims()[0], matrix.shape().dims()[1]);
293 let vec_len = vector.shape().dims()[0];
294
295 if n != vec_len {
296 return Err(TorshError::InvalidArgument(format!(
297 "Incompatible dimensions for matrix-vector multiplication: {m}x{n} and {vec_len}"
298 )));
299 }
300
301 let mut result_data = vec![0.0f32; m];
303 for (i, result_item) in result_data.iter_mut().enumerate().take(m) {
304 let mut sum = 0.0;
305 for j in 0..n {
306 sum += matrix.get(&[i, j])? * vector.get(&[j])?;
307 }
308 *result_item = sum;
309 }
310
311 Tensor::from_data(result_data, vec![m], matrix.device())
312}
313
314pub fn vecmat(vector: &Tensor, matrix: &Tensor) -> TorshResult<Tensor> {
316 if vector.shape().ndim() != 1 {
317 return Err(TorshError::InvalidArgument(
318 "vecmat requires 1D vector".to_string(),
319 ));
320 }
321
322 if matrix.shape().ndim() != 2 {
323 return Err(TorshError::InvalidArgument(
324 "vecmat requires 2D matrix".to_string(),
325 ));
326 }
327
328 let vec_len = vector.shape().dims()[0];
329 let (m, n) = (matrix.shape().dims()[0], matrix.shape().dims()[1]);
330
331 if vec_len != m {
332 return Err(TorshError::InvalidArgument(format!(
333 "Incompatible dimensions for vector-matrix multiplication: {vec_len} and {m}x{n}"
334 )));
335 }
336
337 let mut result_data = vec![0.0f32; n];
339 for (j, result_item) in result_data.iter_mut().enumerate().take(n) {
340 let mut sum = 0.0;
341 for i in 0..m {
342 sum += vector.get(&[i])? * matrix.get(&[i, j])?;
343 }
344 *result_item = sum;
345 }
346
347 Tensor::from_data(result_data, vec![n], vector.device())
348}
349
350pub fn cond(tensor: &Tensor, p: Option<&str>) -> TorshResult<f32> {
352 validate_square_matrix(tensor, "Condition number computation")?;
353
354 let p = p.unwrap_or("2");
355
356 match p {
357 "2" => {
358 let (_, s, _) = decomposition::svd(tensor, false)?;
360
361 let s_shape = s.shape();
363 let s_dims = s_shape.dims();
364 let min_dim = s_dims[0];
365
366 if min_dim == 0 {
367 return Ok(f32::INFINITY);
368 }
369
370 let mut max_sv = 0.0f32;
371 let mut min_sv = f32::INFINITY;
372
373 for i in 0..min_dim {
374 let sv = s.get(&[i])?;
375 if sv > max_sv {
376 max_sv = sv;
377 }
378 if sv < min_sv && sv > 1e-12 {
379 min_sv = sv;
380 }
381 }
382
383 if min_sv == f32::INFINITY || min_sv < 1e-12 {
384 Ok(f32::INFINITY)
385 } else {
386 Ok(max_sv / min_sv)
387 }
388 }
389 "1" => {
390 let norm_a = matrix_functions::matrix_norm(tensor, Some("1"))?;
392 let a_inv = crate::solvers::inv(tensor)?;
393 let norm_a_inv = matrix_functions::matrix_norm(&a_inv, Some("1"))?;
394 Ok(norm_a * norm_a_inv)
395 }
396 "inf" => {
397 let norm_a = matrix_functions::matrix_norm(tensor, Some("inf"))?;
399 let a_inv = crate::solvers::inv(tensor)?;
400 let norm_a_inv = matrix_functions::matrix_norm(&a_inv, Some("inf"))?;
401 Ok(norm_a * norm_a_inv)
402 }
403 "fro" => {
404 let norm_a = matrix_functions::matrix_norm(tensor, Some("fro"))?;
406 let a_inv = crate::solvers::inv(tensor)?;
407 let norm_a_inv = matrix_functions::matrix_norm(&a_inv, Some("fro"))?;
408 Ok(norm_a * norm_a_inv)
409 }
410 _ => Err(TorshError::InvalidArgument(format!(
411 "Unknown norm type for condition number: {p}"
412 ))),
413 }
414}
415
/// Estimate the condition number of a square matrix without a full SVD.
///
/// For `p = "2"` (the default) the largest singular value is estimated by
/// power iteration on A^T·A, and the smallest by a short inverse power
/// iteration (with a determinant-based fallback). For `"1"` / `"inf"` the
/// exact ||A||_p * ||A^-1||_p product is computed via an explicit inverse.
///
/// `max_iter` (default 100) bounds the power-iteration steps.
///
/// # Errors
/// `InvalidArgument` for non-2-D / non-square input or an unknown norm name;
/// propagates failures from inversion, norms, and the tolerance helper.
pub fn cond_estimate(
    tensor: &Tensor,
    p: Option<&str>,
    max_iter: Option<usize>,
) -> TorshResult<f32> {
    if tensor.shape().ndim() != 2 {
        return Err(TorshError::InvalidArgument(
            "Condition number estimation requires 2D tensor".to_string(),
        ));
    }

    let (m, n) = (tensor.shape().dims()[0], tensor.shape().dims()[1]);
    if m != n {
        return Err(TorshError::InvalidArgument(
            "Condition number estimation requires square matrix".to_string(),
        ));
    }

    let p = p.unwrap_or("2");
    let max_iter = max_iter.unwrap_or(100);
    // Convergence threshold scaled to the magnitude of the input matrix.
    let tolerance = get_relative_tolerance(tensor, 1e-6)?;

    match p {
        "2" => {
            // Power iteration on A^T·A: its largest eigenvalue is sigma_max^2.
            let at = tensor.t()?;
            let ata = at.matmul(tensor)?;

            // Deterministic pseudo-random start vector (sin ramp) so the
            // estimate is reproducible run-to-run.
            let v = torsh_tensor::creation::zeros::<f32>(&[n])?;
            for i in 0..n {
                v.set(&[i], (1.0 + i as f32 * 0.1).sin())?;
            }

            let mut max_eigenvalue = 0.0f32;
            for _ in 0..max_iter {
                let av = ata.matmul(&v.unsqueeze(1)?)?;
                let av = av.squeeze(1)?;

                // Rayleigh quotient v^T(A^T A)v / v^T v as the eigenvalue estimate.
                let numerator = vector_inner_product(&v, &av)?;
                let denominator = vector_inner_product(&v, &v)?;

                let new_eigenvalue = if denominator > tolerance {
                    numerator / denominator
                } else {
                    0.0
                };

                // Stop once the estimate stabilizes.
                if (new_eigenvalue - max_eigenvalue).abs() < tolerance {
                    max_eigenvalue = new_eigenvalue;
                    break;
                }
                max_eigenvalue = new_eigenvalue;

                let norm = vector_norm_2(&av)?;

                // A (near-)zero iterate means A^T·A annihilated v; stop.
                if norm < tolerance {
                    break;
                }

                // Normalize in place for the next iteration.
                for i in 0..n {
                    v.set(&[i], av.get(&[i])? / norm)?;
                }
            }

            let max_singular_value = max_eigenvalue.sqrt();

            // Crude lower bound for sigma_min: |det(A)| / ||A||_F^(n-1)
            // (from |det| = prod of singular values). Used as a fallback and
            // as the starting value before inverse iteration refines it.
            let mut min_singular_value = if max_singular_value > tolerance {
                let det_val = det(tensor)?;
                let matrix_norm = matrix_functions::matrix_norm(tensor, Some("fro"))?;

                if det_val.abs() > tolerance && matrix_norm > tolerance {
                    det_val.abs() / matrix_norm.powi(n as i32 - 1)
                } else {
                    tolerance
                }
            } else {
                tolerance
            };

            // Refine sigma_min with a few steps of power iteration on
            // (A^T·A)^-1, whose dominant eigenvalue is 1 / sigma_min^2.
            if min_singular_value > tolerance {
                let inv_ata = crate::solvers::inv(&ata)?;
                // Different deterministic start vector (cos ramp) than above.
                let v_min = torsh_tensor::creation::zeros::<f32>(&[n])?;
                for i in 0..n {
                    v_min.set(&[i], (1.0 + i as f32 * 0.3).cos())?;
                }

                for _ in 0..5 {
                    let av = inv_ata.matmul(&v_min.unsqueeze(1)?)?;
                    let av = av.squeeze(1)?;

                    let mut norm = 0.0f32;
                    for i in 0..n {
                        let val = av.get(&[i])?;
                        norm += val * val;
                    }
                    norm = norm.sqrt();

                    if norm < tolerance {
                        break;
                    }

                    for i in 0..n {
                        v_min.set(&[i], av.get(&[i])? / norm)?;
                    }

                    // norm -> lambda_max((A^T A)^-1) = 1/sigma_min^2,
                    // hence sigma_min = sqrt(1 / norm).
                    min_singular_value = (1.0 / norm).sqrt();
                }
            }

            if min_singular_value < tolerance {
                Ok(f32::INFINITY)
            } else {
                Ok(max_singular_value / min_singular_value)
            }
        }
        "1" => {
            // Exact 1-norm condition number via explicit inverse.
            let norm_a = matrix_functions::matrix_norm(tensor, Some("1"))?;

            let inv_a = crate::solvers::inv(tensor)?;
            let norm_inv_a = matrix_functions::matrix_norm(&inv_a, Some("1"))?;

            Ok(norm_a * norm_inv_a)
        }
        "inf" => {
            // Exact infinity-norm condition number via explicit inverse.
            let norm_a = matrix_functions::matrix_norm(tensor, Some("inf"))?;
            let inv_a = crate::solvers::inv(tensor)?;
            let norm_inv_a = matrix_functions::matrix_norm(&inv_a, Some("inf"))?;

            Ok(norm_a * norm_inv_a)
        }
        _ => Err(TorshError::InvalidArgument(format!(
            "Unknown norm type for condition estimation: {p}"
        ))),
    }
}
572
573pub fn stability_analysis(tensor: &Tensor) -> TorshResult<(f32, usize, usize, f32)> {
579 if tensor.shape().ndim() != 2 {
580 return Err(TorshError::InvalidArgument(
581 "Stability analysis requires 2D tensor".to_string(),
582 ));
583 }
584
585 let (m, n) = (tensor.shape().dims()[0], tensor.shape().dims()[1]);
586
587 let condition_number = cond_estimate(tensor, Some("2"), Some(50))?;
589
590 let rank_strict = matrix_rank(tensor, Some(1e-12))?;
592 let rank_numerical = matrix_rank(tensor, Some(1e-8))?;
593
594 let (_, s, _) = decomposition::svd(tensor, false)?;
596 let min_dim = m.min(n);
597
598 let mut stability_metric = 0.0f32;
599 if min_dim > 1 {
600 let largest_sv = s.get(&[0])?;
601 let second_largest_sv = if min_dim > 1 {
602 s.get(&[1])?
603 } else {
604 largest_sv
605 };
606
607 if second_largest_sv > 1e-12 && largest_sv > 1e-12 {
608 stability_metric = second_largest_sv / largest_sv;
609 }
610 }
611
612 Ok((
613 condition_number,
614 rank_strict,
615 rank_numerical,
616 stability_metric,
617 ))
618}
619
620pub fn einsum(subscripts: &str, operands: &[&Tensor]) -> TorshResult<Tensor> {
624 if operands.is_empty() {
625 return Err(TorshError::InvalidArgument(
626 "einsum requires at least one operand".to_string(),
627 ));
628 }
629
630 let parts: Vec<&str> = subscripts.split("->").collect();
632 if parts.len() != 2 {
633 return Err(TorshError::InvalidArgument(
634 "einsum subscripts must contain '->' separator".to_string(),
635 ));
636 }
637
638 let input_subscripts = parts[0];
639 let output_subscript = parts[1];
640
641 match (input_subscripts, output_subscript) {
643 ("ij,jk", "ik") => {
645 if operands.len() != 2 {
646 return Err(TorshError::InvalidArgument(
647 "Matrix multiplication requires exactly 2 operands".to_string(),
648 ));
649 }
650 matmul(operands[0], operands[1])
651 }
652 ("ii", "i") => {
654 if operands.len() != 1 {
655 return Err(TorshError::InvalidArgument(
656 "Diagonal extraction requires exactly 1 operand".to_string(),
657 ));
658 }
659 special_matrices::diag(operands[0], 0)
660 }
661 ("ii", "") => {
663 if operands.len() != 1 {
664 return Err(TorshError::InvalidArgument(
665 "Trace requires exactly 1 operand".to_string(),
666 ));
667 }
668 let trace_val = trace(operands[0])?;
669 Tensor::from_data(vec![trace_val], vec![], operands[0].device())
670 }
671 ("ij", "ji") => {
673 if operands.len() != 1 {
674 return Err(TorshError::InvalidArgument(
675 "Transpose requires exactly 1 operand".to_string(),
676 ));
677 }
678 operands[0].transpose(-2, -1)
679 }
680 ("bij,bjk", "bik") => {
682 if operands.len() != 2 {
683 return Err(TorshError::InvalidArgument(
684 "Batch matrix multiplication requires exactly 2 operands".to_string(),
685 ));
686 }
687 matmul(operands[0], operands[1])
689 }
690 ("i,j", "ij") => {
692 if operands.len() != 2 {
693 return Err(TorshError::InvalidArgument(
694 "Outer product requires exactly 2 operands".to_string(),
695 ));
696 }
697 outer(operands[0], operands[1])
698 }
699 ("i,i", "") => {
701 if operands.len() != 2 {
702 return Err(TorshError::InvalidArgument(
703 "Inner product requires exactly 2 operands".to_string(),
704 ));
705 }
706 inner(operands[0], operands[1])
707 }
708 _ => Err(TorshError::InvalidArgument(format!(
709 "Unsupported einsum pattern: {input_subscripts} -> {output_subscript}"
710 ))),
711 }
712}
713
714pub fn outer(a: &Tensor, b: &Tensor) -> TorshResult<Tensor> {
716 if a.shape().ndim() != 1 || b.shape().ndim() != 1 {
717 return Err(TorshError::InvalidArgument(
718 "Outer product requires 1D tensors".to_string(),
719 ));
720 }
721
722 let a_len = a.shape().dims()[0];
723 let b_len = b.shape().dims()[0];
724
725 let mut result_data = vec![0.0f32; a_len * b_len];
726 for i in 0..a_len {
727 let a_val = a.get(&[i])?; for j in 0..b_len {
729 result_data[i * b_len + j] = a_val * b.get(&[j])?;
730 }
731 }
732
733 Tensor::from_data(result_data, vec![a_len, b_len], a.device())
734}
735
736pub fn inner(a: &Tensor, b: &Tensor) -> TorshResult<Tensor> {
738 if a.shape().ndim() != 1 || b.shape().ndim() != 1 {
739 return Err(TorshError::InvalidArgument(
740 "Inner product requires 1D tensors".to_string(),
741 ));
742 }
743
744 let a_len = a.shape().dims()[0];
745 let b_len = b.shape().dims()[0];
746
747 if a_len != b_len {
748 return Err(TorshError::InvalidArgument(
749 "Inner product requires vectors of the same length".to_string(),
750 ));
751 }
752
753 let mut sum = 0.0f32;
754 for i in 0..a_len {
755 sum += a.get(&[i])? * b.get(&[i])?;
756 }
757
758 Tensor::from_data(vec![sum], vec![], a.device())
759}
760
/// Structural and numerical report for a matrix, produced by [`analyze_matrix`].
#[derive(Debug, Clone)]
pub struct MatrixAnalysis {
    /// (rows, cols) of the analyzed matrix.
    pub dimensions: (usize, usize),
    /// True when rows == cols.
    pub is_square: bool,
    /// True when the matrix equals its transpose (within a fixed tolerance).
    pub is_symmetric: bool,
    /// True when symmetric and a Cholesky factorization succeeds.
    pub is_positive_definite: bool,
    /// True when every off-diagonal entry is (numerically) zero.
    pub is_diagonal: bool,
    /// True when diagonal, square, and all diagonal entries equal 1.
    pub is_identity: bool,
    /// Determinant; `None` for non-square matrices.
    pub determinant: Option<f32>,
    /// Trace; `None` for non-square matrices.
    pub trace: Option<f32>,
    /// Numerical rank (singular values above the default tolerance).
    pub rank: usize,
    /// Estimated 2-norm condition number; `None` for non-square matrices.
    pub condition_number: Option<f32>,
    /// (Frobenius norm, 1-norm, infinity-norm).
    pub norms: (f32, f32, f32),
    /// (smallest nonzero |entry|, largest |entry|).
    pub value_range: (f32, f32),
    /// Fraction of entries that are numerically zero (0.0..=1.0).
    pub sparsity: f32,
    /// Human-readable solver suggestion derived from the structure above.
    pub recommended_solver: String,
    /// Human-readable conditioning verdict.
    pub stability_assessment: String,
}
795
/// Produce a full [`MatrixAnalysis`] report for a 2-D matrix.
///
/// Inspects shape, symmetry, diagonality, positive definiteness, rank, norms,
/// sparsity, and conditioning, then derives a suggested solver and a
/// human-readable stability assessment.
///
/// # Errors
/// `InvalidArgument` when `tensor` is not 2-D or an element comparison fails.
/// Metric sub-computations that fail are degraded to defaults (0.0 / `None`)
/// rather than propagated.
pub fn analyze_matrix(tensor: &Tensor) -> TorshResult<MatrixAnalysis> {
    if tensor.shape().ndim() != 2 {
        return Err(TorshError::InvalidArgument(
            "Matrix analysis requires 2D tensor".to_string(),
        ));
    }

    let (m, n) = (tensor.shape().dims()[0], tensor.shape().dims()[1]);
    let is_square = m == n;
    // Fixed absolute tolerance used for all element-wise comparisons below.
    let tolerance = 1e-6f32;

    // Symmetry: compare every element with its transpose counterpart,
    // bailing out of both loops on the first mismatch.
    let mut is_symmetric = false;
    if is_square {
        is_symmetric = true;
        for i in 0..m {
            for j in 0..n {
                if (tensor.get(&[i, j])? - tensor.get(&[j, i])?).abs() > tolerance {
                    is_symmetric = false;
                    break;
                }
            }
            if !is_symmetric {
                break;
            }
        }
    }

    // Diagonality check; the diagonal entries are collected for reuse in the
    // identity and positive-definiteness checks below.
    let mut is_diagonal = true;
    let mut diagonal_values = Vec::new();
    for i in 0..m {
        for j in 0..n {
            let val = tensor.get(&[i, j])?;
            if i == j {
                diagonal_values.push(val);
            } else if val.abs() > tolerance {
                is_diagonal = false;
            }
        }
    }

    // Identity = square + diagonal + every diagonal entry equal to 1.
    let mut is_identity = is_diagonal && is_square;
    if is_identity {
        for &diag_val in &diagonal_values {
            if (diag_val - 1.0).abs() > tolerance {
                is_identity = false;
                break;
            }
        }
    }

    // Positive definiteness: cheap necessary condition first (positive
    // diagonal), then a Cholesky attempt as the definitive test.
    let mut is_positive_definite = false;
    if is_symmetric && is_square {
        is_positive_definite = diagonal_values.iter().all(|&val| val > 0.0);
        if is_positive_definite {
            is_positive_definite = decomposition::cholesky(tensor, false).is_ok();
        }
    }

    // Square-only metrics; failures degrade to 0.0 instead of erroring out.
    let determinant = if is_square {
        Some(det(tensor).unwrap_or(0.0))
    } else {
        None
    };

    let trace_val = if is_square {
        Some(trace(tensor).unwrap_or(0.0))
    } else {
        None
    };

    // Fall back to full rank if the SVD-based computation fails.
    let rank = matrix_rank(tensor, None).unwrap_or(m.min(n));

    let condition_number = if is_square {
        cond_estimate(tensor, Some("2"), Some(50)).ok()
    } else {
        None
    };

    // Norm triple: (Frobenius, 1-norm, infinity-norm); failures become 0.0.
    let fro_norm = matrix_functions::matrix_norm(tensor, Some("fro")).unwrap_or(0.0);
    let one_norm = matrix_functions::matrix_norm(tensor, Some("1")).unwrap_or(0.0);
    let inf_norm = matrix_functions::matrix_norm(tensor, Some("inf")).unwrap_or(0.0);
    let norms = (fro_norm, one_norm, inf_norm);

    // Value range and sparsity: entries below `tolerance` count as zeros.
    // NOTE(review): `min_abs` stays f32::INFINITY when every entry is
    // (near-)zero — consumers of `value_range` should be aware.
    let mut min_abs = f32::INFINITY;
    let mut max_abs = 0.0f32;
    let mut zero_count = 0;
    let total_elements = m * n;

    for i in 0..m {
        for j in 0..n {
            let val = tensor.get(&[i, j]).unwrap_or(0.0);
            let abs_val = val.abs();
            if abs_val < tolerance {
                zero_count += 1;
            } else {
                min_abs = min_abs.min(abs_val);
            }
            max_abs = max_abs.max(abs_val);
        }
    }

    let value_range = (min_abs, max_abs);
    let sparsity = zero_count as f32 / total_elements as f32;

    // Heuristic solver recommendation, ordered from most to least specific
    // structure; the first matching branch wins.
    let recommended_solver = if is_identity {
        "Identity matrix: use trivial solver".to_string()
    } else if is_diagonal {
        "Diagonal matrix: use diagonal solver".to_string()
    } else if is_positive_definite {
        "Positive definite: use Cholesky decomposition".to_string()
    } else if is_symmetric {
        "Symmetric matrix: use symmetric solver (LDLT)".to_string()
    } else if sparsity > 0.5 {
        "Sparse matrix: use sparse iterative solvers (CG, GMRES, BiCGSTAB)".to_string()
    } else if let Some(cond_num) = condition_number {
        if cond_num < 100.0 {
            "Well-conditioned: use LU decomposition".to_string()
        } else if cond_num < 1e6 {
            "Moderately conditioned: use LU with iterative refinement".to_string()
        } else {
            "Ill-conditioned: use regularization or specialized methods".to_string()
        }
    } else if m > n {
        "Overdetermined system: use QR decomposition or least squares".to_string()
    } else if m < n {
        "Underdetermined system: use minimum norm solution".to_string()
    } else {
        "General square matrix: use LU decomposition".to_string()
    };

    // Qualitative stability verdict keyed off the estimated condition number.
    let stability_assessment = if is_identity {
        "Excellent: Identity matrix is perfectly conditioned".to_string()
    } else if let Some(cond_num) = condition_number {
        if cond_num < 10.0 {
            "Excellent: Very well-conditioned matrix".to_string()
        } else if cond_num < 100.0 {
            "Good: Well-conditioned matrix".to_string()
        } else if cond_num < 1e6 {
            "Moderate: Reasonable conditioning, monitor for accuracy".to_string()
        } else if cond_num < 1e12 {
            format!("Poor: Ill-conditioned (κ ≈ {cond_num:.2e}), expect numerical issues")
        } else {
            "Critical: Severely ill-conditioned, results may be unreliable".to_string()
        }
    } else if rank < m.min(n) {
        "Poor: Rank-deficient matrix, singular or near-singular".to_string()
    } else {
        "Unknown: Unable to assess conditioning for non-square matrix".to_string()
    };

    Ok(MatrixAnalysis {
        dimensions: (m, n),
        is_square,
        is_symmetric,
        is_positive_definite,
        is_diagonal,
        is_identity,
        determinant,
        trace: trace_val,
        rank,
        condition_number,
        norms,
        value_range,
        sparsity,
        recommended_solver,
        stability_assessment,
    })
}
978
979pub fn bmm(batch1: &Tensor, batch2: &Tensor) -> TorshResult<Tensor> {
981 if batch1.shape().ndim() != 3 || batch2.shape().ndim() != 3 {
982 return Err(TorshError::InvalidArgument(
983 "Batch matrix multiplication requires 3D tensors".to_string(),
984 ));
985 }
986
987 let batch1_shape = batch1.shape();
988 let batch2_shape = batch2.shape();
989 let batch1_dims = batch1_shape.dims();
990 let batch2_dims = batch2_shape.dims();
991
992 let (b1, _n1, k1) = (batch1_dims[0], batch1_dims[1], batch1_dims[2]);
993 let (b2, k2, _m2) = (batch2_dims[0], batch2_dims[1], batch2_dims[2]);
994
995 if b1 != b2 {
996 return Err(TorshError::InvalidArgument(
997 "Batch sizes must match for batch matrix multiplication".to_string(),
998 ));
999 }
1000
1001 if k1 != k2 {
1002 return Err(TorshError::InvalidArgument(
1003 "Inner dimensions must match for batch matrix multiplication".to_string(),
1004 ));
1005 }
1006
1007 batch1.matmul(batch2)
1009}
1010
1011#[cfg(test)]
1012mod tests {
1013 use super::*;
1014 use approx::assert_relative_eq;
1015 use torsh_tensor::creation::eye;
1016
1017 fn create_test_matrix_2x2() -> TorshResult<Tensor> {
1018 let data = vec![1.0f32, 2.0, 3.0, 4.0];
1020 Tensor::from_data(data, vec![2, 2], torsh_core::DeviceType::Cpu)
1021 }
1022
1023 fn create_test_matrix_3x3() -> TorshResult<Tensor> {
1024 let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
1026 Tensor::from_data(data, vec![3, 3], torsh_core::DeviceType::Cpu)
1027 }
1028
1029 fn create_test_vector() -> TorshResult<Tensor> {
1030 let data = vec![1.0f32, 2.0, 3.0];
1032 Tensor::from_data(data, vec![3], torsh_core::DeviceType::Cpu)
1033 }
1034
1035 #[test]
1036 fn test_determinant() -> TorshResult<()> {
1037 let mat = create_test_matrix_2x2()?;
1039 let det_val = det(&mat)?;
1040
1041 assert_relative_eq!(det_val, -2.0, epsilon = 1e-6);
1043
1044 let identity = eye::<f32>(3)?;
1046 let det_identity = det(&identity)?;
1047 assert_relative_eq!(det_identity, 1.0, epsilon = 1e-6);
1048
1049 Ok(())
1050 }
1051
1052 #[test]
1053 fn test_trace() -> TorshResult<()> {
1054 let mat = create_test_matrix_3x3()?;
1055 let trace_val = trace(&mat)?;
1056
1057 assert_relative_eq!(trace_val, 15.0, epsilon = 1e-6);
1059
1060 let identity = eye::<f32>(4)?;
1062 let trace_identity = trace(&identity)?;
1063 assert_relative_eq!(trace_identity, 4.0, epsilon = 1e-6);
1064
1065 Ok(())
1066 }
1067
1068 #[test]
1069 fn test_matrix_rank() -> TorshResult<()> {
1070 let mat = create_test_matrix_2x2()?;
1072 let rank = matrix_rank(&mat, None)?;
1073 assert!((1..=2).contains(&rank));
1076
1077 let identity = eye::<f32>(2)?;
1079 let rank_identity = matrix_rank(&identity, None)?;
1080 assert!((1..=2).contains(&rank_identity));
1081
1082 Ok(())
1083 }
1084
1085 #[test]
1086 fn test_matvec() -> TorshResult<()> {
1087 let identity = eye::<f32>(3)?;
1089 let vec = create_test_vector()?;
1090
1091 let result = matvec(&identity, &vec)?;
1092
1093 assert_eq!(result.shape().dims(), &[3]);
1095 for i in 0..3 {
1096 assert_relative_eq!(result.get(&[i])?, vec.get(&[i])?, epsilon = 1e-6);
1097 }
1098
1099 Ok(())
1100 }
1101
1102 #[test]
1103 fn test_vecmat() -> TorshResult<()> {
1104 let vec = create_test_vector()?;
1106 let identity = eye::<f32>(3)?;
1107
1108 let result = vecmat(&vec, &identity)?;
1109
1110 assert_eq!(result.shape().dims(), &[3]);
1112 for i in 0..3 {
1113 assert_relative_eq!(result.get(&[i])?, vec.get(&[i])?, epsilon = 1e-6);
1114 }
1115
1116 Ok(())
1117 }
1118
1119 #[test]
1120 fn test_matmul() -> TorshResult<()> {
1121 let mat1 = create_test_matrix_2x2()?;
1122 let mat2 = eye::<f32>(2)?;
1123
1124 let result = matmul(&mat1, &mat2)?;
1125
1126 assert_eq!(result.shape().dims(), &[2, 2]);
1128 for i in 0..2 {
1129 for j in 0..2 {
1130 assert_relative_eq!(result.get(&[i, j])?, mat1.get(&[i, j])?, epsilon = 1e-6);
1131 }
1132 }
1133
1134 Ok(())
1135 }
1136
1137 #[test]
1138 fn test_lu_decomposition() -> TorshResult<()> {
1139 let mat = create_test_matrix_2x2()?;
1140 let (p, l, u) = decomposition::lu(&mat)?;
1141
1142 assert_eq!(p.shape().dims(), &[2, 2]);
1144 assert_eq!(l.shape().dims(), &[2, 2]);
1145 assert_eq!(u.shape().dims(), &[2, 2]);
1146
1147 assert!(l.get(&[0, 1])?.abs() < 1e-6); assert!(l.get(&[0, 0])? > 0.0); assert!(l.get(&[1, 1])? > 0.0);
1156
1157 Ok(())
1158 }
1159
1160 #[test]
1161 fn test_qr_decomposition() -> TorshResult<()> {
1162 let mat = create_test_matrix_2x2()?;
1163 let (q, r) = decomposition::qr(&mat)?;
1164
1165 assert_eq!(q.shape().dims(), &[2, 2]);
1167 assert_eq!(r.shape().dims(), &[2, 2]);
1168
1169 let qr_product = matmul(&q, &r)?;
1171
1172 for i in 0..2 {
1173 for j in 0..2 {
1174 assert_relative_eq!(qr_product.get(&[i, j])?, mat.get(&[i, j])?, epsilon = 1e-4);
1175 }
1176 }
1177
1178 Ok(())
1179 }
1180
1181 #[test]
1182 fn test_cholesky_decomposition() -> TorshResult<()> {
1183 let data = vec![4.0f32, 2.0, 2.0, 3.0];
1186 let mat = Tensor::from_data(data, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1187
1188 let l = decomposition::cholesky(&mat, false)?;
1189
1190 assert_eq!(l.shape().dims(), &[2, 2]);
1192
1193 let lt = l.transpose(-2, -1)?;
1195 let llt_product = matmul(&l, <)?;
1196
1197 for i in 0..2 {
1198 for j in 0..2 {
1199 assert_relative_eq!(llt_product.get(&[i, j])?, mat.get(&[i, j])?, epsilon = 1e-5);
1200 }
1201 }
1202
1203 Ok(())
1204 }
1205
1206 #[test]
1207 fn test_matrix_inverse() -> TorshResult<()> {
1208 let identity = eye::<f32>(2)?;
1210 let inv_identity = crate::solvers::inv(&identity)?;
1211
1212 assert_eq!(inv_identity.shape().dims(), &[2, 2]);
1214
1215 for i in 0..2 {
1217 for j in 0..2 {
1218 let expected = if i == j { 1.0 } else { 0.0 };
1219 assert_relative_eq!(inv_identity.get(&[i, j])?, expected, epsilon = 1e-6);
1220 }
1221 }
1222
1223 Ok(())
1224 }
1225
1226 #[test]
1227 fn test_condition_number() -> TorshResult<()> {
1228 let identity = eye::<f32>(3)?;
1230 let cond_num = cond(&identity, Some("2"))?;
1231 assert_relative_eq!(cond_num, 1.0, epsilon = 1e-5);
1232
1233 let mat = create_test_matrix_2x2()?;
1235 let cond_num = cond(&mat, Some("2"))?;
1236 assert!(cond_num > 1.0); assert!(cond_num < 100.0); Ok(())
1240 }
1241
1242 #[test]
1243 fn test_matrix_norms() -> TorshResult<()> {
1244 let mat = create_test_matrix_2x2()?;
1245
1246 let fro_norm = matrix_functions::matrix_norm(&mat, Some("fro"))?;
1248 assert!(fro_norm > 0.0);
1249
1250 let one_norm = matrix_functions::matrix_norm(&mat, Some("1"))?;
1252 assert!(one_norm > 0.0);
1253
1254 let inf_norm = matrix_functions::matrix_norm(&mat, Some("inf"))?;
1256 assert!(inf_norm > 0.0);
1257
1258 Ok(())
1259 }
1260
1261 #[test]
1262 fn test_matrix_analysis() -> TorshResult<()> {
1263 let identity = eye::<f32>(3)?;
1265 let analysis = analyze_matrix(&identity)?;
1266
1267 assert_eq!(analysis.dimensions, (3, 3));
1268 assert!(analysis.is_square);
1269 assert!(analysis.is_symmetric);
1270 assert!(analysis.is_diagonal);
1271 assert!(analysis.is_identity);
1272 assert!(analysis.is_positive_definite);
1273
1274 if let Some(det) = analysis.determinant {
1275 assert_relative_eq!(det, 1.0, epsilon = 1e-5);
1276 }
1277
1278 if let Some(tr) = analysis.trace {
1279 assert_relative_eq!(tr, 3.0, epsilon = 1e-5);
1280 }
1281
1282 assert!(analysis.recommended_solver.contains("Identity"));
1283 assert!(analysis.stability_assessment.contains("Excellent"));
1284
1285 let mat = create_test_matrix_2x2()?;
1287 let analysis = analyze_matrix(&mat)?;
1288
1289 assert_eq!(analysis.dimensions, (2, 2));
1290 assert!(analysis.is_square);
1291 assert!(!analysis.is_symmetric);
1292 assert!(!analysis.is_diagonal);
1293 assert!(!analysis.is_identity);
1294
1295 assert!(analysis.rank >= 1);
1296 assert!(analysis.norms.0 > 0.0); assert!(analysis.norms.1 > 0.0); assert!(analysis.norms.2 > 0.0); let rect_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1302 let rect_mat = Tensor::from_data(rect_data, vec![2, 3], torsh_core::DeviceType::Cpu)?;
1303 let analysis = analyze_matrix(&rect_mat)?;
1304
1305 assert_eq!(analysis.dimensions, (2, 3));
1306 assert!(!analysis.is_square);
1307 assert!(analysis.determinant.is_none());
1308 assert!(analysis.trace.is_none());
1309 assert!(analysis.condition_number.is_none());
1310
1311 Ok(())
1312 }
1313
1314 #[test]
1315 fn test_allclose() -> TorshResult<()> {
1316 let data1 = vec![1.0f32, 2.0, 3.0, 4.0];
1317 let mat1 = Tensor::from_data(data1, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1318
1319 let data2 = vec![1.0001f32, 1.9999, 3.0001, 3.9999];
1320 let mat2 = Tensor::from_data(data2, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1321
1322 assert!(allclose(&mat1, &mat2, Some(1e-3), Some(1e-3))?);
1324
1325 assert!(!allclose(&mat1, &mat2, Some(1e-8), Some(1e-8))?);
1327
1328 let data3 = vec![1.0f32, 2.0, 3.0];
1330 let mat3 = Tensor::from_data(data3, vec![3], torsh_core::DeviceType::Cpu)?;
1331 assert!(!allclose(&mat1, &mat3, None, None)?);
1332
1333 assert!(allclose(&mat1, &mat1, None, None)?);
1335
1336 Ok(())
1337 }
1338
1339 #[test]
1340 fn test_isclose() -> TorshResult<()> {
1341 let data1 = vec![1.0f32, 2.0, 3.0, 4.0];
1342 let mat1 = Tensor::from_data(data1, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1343
1344 let data2 = vec![1.0001f32, 2.1, 3.0001, 3.9999];
1345 let mat2 = Tensor::from_data(data2, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1346
1347 let result = isclose(&mat1, &mat2, Some(1e-3), Some(1e-3))?;
1348
1349 assert_eq!(result.shape().dims(), &[2, 2]);
1351
1352 assert_relative_eq!(result.get(&[0, 0])?, 1.0, epsilon = 1e-6); assert_relative_eq!(result.get(&[0, 1])?, 0.0, epsilon = 1e-6); assert_relative_eq!(result.get(&[1, 0])?, 1.0, epsilon = 1e-6); assert_relative_eq!(result.get(&[1, 1])?, 1.0, epsilon = 1e-6); Ok(())
1359 }
1360
1361 #[test]
1362 fn test_matrix_equals() -> TorshResult<()> {
1363 let data1 = vec![1.0f32, 2.0, 3.0, 4.0];
1364 let mat1 = Tensor::from_data(data1.clone(), vec![2, 2], torsh_core::DeviceType::Cpu)?;
1365 let mat2 = Tensor::from_data(data1, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1366
1367 assert!(matrix_equals(&mat1, &mat2)?);
1369
1370 let data3 = vec![1.0001f32, 2.0, 3.0, 4.0];
1371 let mat3 = Tensor::from_data(data3, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1372
1373 assert!(!matrix_equals(&mat1, &mat3)?);
1375
1376 Ok(())
1377 }
1378
1379 #[test]
1380 fn test_frobenius_distance() -> TorshResult<()> {
1381 let data1 = vec![1.0f32, 2.0, 3.0, 4.0];
1382 let mat1 = Tensor::from_data(data1, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1383
1384 let data2 = vec![2.0f32, 3.0, 4.0, 5.0];
1385 let mat2 = Tensor::from_data(data2, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1386
1387 let distance = frobenius_distance(&mat1, &mat2)?;
1388
1389 assert_relative_eq!(distance, 2.0, epsilon = 1e-6);
1391
1392 let zero_distance = frobenius_distance(&mat1, &mat1)?;
1394 assert_relative_eq!(zero_distance, 0.0, epsilon = 1e-6);
1395
1396 Ok(())
1397 }
1398
1399 #[test]
1400 fn test_is_symmetric() -> TorshResult<()> {
1401 let sym_data = vec![1.0f32, 2.0, 2.0, 3.0];
1403 let sym_mat = Tensor::from_data(sym_data, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1404
1405 assert!(is_symmetric(&sym_mat, None)?);
1406
1407 let nonsym_data = vec![1.0f32, 2.0, 3.0, 4.0];
1409 let nonsym_mat = Tensor::from_data(nonsym_data, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1410
1411 assert!(!is_symmetric(&nonsym_mat, None)?);
1412
1413 let rect_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1415 let rect_mat = Tensor::from_data(rect_data, vec![2, 3], torsh_core::DeviceType::Cpu)?;
1416
1417 assert!(!is_symmetric(&rect_mat, None)?);
1418
1419 let approx_sym_data = vec![1.0f32, 2.0, 2.0001, 3.0];
1421 let approx_sym_mat =
1422 Tensor::from_data(approx_sym_data, vec![2, 2], torsh_core::DeviceType::Cpu)?;
1423
1424 assert!(is_symmetric(&approx_sym_mat, Some(1e-3))?);
1425 assert!(!is_symmetric(&approx_sym_mat, Some(1e-6))?);
1426
1427 Ok(())
1428 }
1429
1430 #[test]
1431 fn test_hadamard_product() -> TorshResult<()> {
1432 let a = Tensor::from_data(
1434 vec![1.0, 2.0, 3.0, 4.0],
1435 vec![2, 2],
1436 torsh_core::DeviceType::Cpu,
1437 )?;
1438 let b = Tensor::from_data(
1439 vec![5.0, 6.0, 7.0, 8.0],
1440 vec![2, 2],
1441 torsh_core::DeviceType::Cpu,
1442 )?;
1443
1444 let h = hadamard(&a, &b)?;
1445
1446 assert_relative_eq!(h.get(&[0, 0])?, 5.0, epsilon = 1e-6);
1448 assert_relative_eq!(h.get(&[0, 1])?, 12.0, epsilon = 1e-6);
1449 assert_relative_eq!(h.get(&[1, 0])?, 21.0, epsilon = 1e-6);
1450 assert_relative_eq!(h.get(&[1, 1])?, 32.0, epsilon = 1e-6);
1451
1452 let h2 = hadamard(&b, &a)?;
1454 for i in 0..2 {
1455 for j in 0..2 {
1456 assert_relative_eq!(h.get(&[i, j])?, h2.get(&[i, j])?, epsilon = 1e-6);
1457 }
1458 }
1459
1460 Ok(())
1461 }
1462
1463 #[test]
1464 fn test_vec_unvec_roundtrip() -> TorshResult<()> {
1465 let original = Tensor::from_data(
1467 vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
1468 vec![2, 3],
1469 torsh_core::DeviceType::Cpu,
1470 )?;
1471
1472 let v = vec_matrix(&original)?;
1474 assert_eq!(v.shape().dims(), &[6]);
1475
1476 assert_relative_eq!(v.get(&[0])?, 1.0, epsilon = 1e-6);
1483 assert_relative_eq!(v.get(&[1])?, 4.0, epsilon = 1e-6);
1484 assert_relative_eq!(v.get(&[2])?, 2.0, epsilon = 1e-6);
1485 assert_relative_eq!(v.get(&[3])?, 5.0, epsilon = 1e-6);
1486
1487 let reconstructed = unvec_matrix(&v, 2, 3)?;
1489 assert_eq!(reconstructed.shape().dims(), &[2, 3]);
1490
1491 for i in 0..2 {
1493 for j in 0..3 {
1494 assert_relative_eq!(
1495 original.get(&[i, j])?,
1496 reconstructed.get(&[i, j])?,
1497 epsilon = 1e-6
1498 );
1499 }
1500 }
1501
1502 Ok(())
1503 }
1504
1505 #[test]
1506 fn test_commutator() -> TorshResult<()> {
1507 let a = Tensor::from_data(
1509 vec![1.0, 2.0, 3.0, 4.0],
1510 vec![2, 2],
1511 torsh_core::DeviceType::Cpu,
1512 )?;
1513 let b = Tensor::from_data(
1514 vec![0.0, 1.0, 1.0, 0.0],
1515 vec![2, 2],
1516 torsh_core::DeviceType::Cpu,
1517 )?;
1518
1519 let comm = commutator(&a, &b)?;
1520
1521 let ab = a.matmul(&b)?;
1523 let ba = b.matmul(&a)?;
1524 let expected = ab.sub(&ba)?;
1525
1526 for i in 0..2 {
1527 for j in 0..2 {
1528 assert_relative_eq!(comm.get(&[i, j])?, expected.get(&[i, j])?, epsilon = 1e-6);
1529 }
1530 }
1531
1532 let comm_ba = commutator(&b, &a)?;
1534 for i in 0..2 {
1535 for j in 0..2 {
1536 assert_relative_eq!(comm.get(&[i, j])?, -comm_ba.get(&[i, j])?, epsilon = 1e-6);
1537 }
1538 }
1539
1540 let comm_aa = commutator(&a, &a)?;
1542 for i in 0..2 {
1543 for j in 0..2 {
1544 assert_relative_eq!(comm_aa.get(&[i, j])?, 0.0, epsilon = 1e-6);
1545 }
1546 }
1547
1548 Ok(())
1549 }
1550
1551 #[test]
1552 fn test_anticommutator() -> TorshResult<()> {
1553 let a = Tensor::from_data(
1555 vec![1.0, 2.0, 3.0, 4.0],
1556 vec![2, 2],
1557 torsh_core::DeviceType::Cpu,
1558 )?;
1559 let b = Tensor::from_data(
1560 vec![0.0, 1.0, 1.0, 0.0],
1561 vec![2, 2],
1562 torsh_core::DeviceType::Cpu,
1563 )?;
1564
1565 let anticomm = anticommutator(&a, &b)?;
1566
1567 let ab = a.matmul(&b)?;
1569 let ba = b.matmul(&a)?;
1570 let expected = ab.add(&ba)?;
1571
1572 for i in 0..2 {
1573 for j in 0..2 {
1574 assert_relative_eq!(
1575 anticomm.get(&[i, j])?,
1576 expected.get(&[i, j])?,
1577 epsilon = 1e-6
1578 );
1579 }
1580 }
1581
1582 let anticomm_ba = anticommutator(&b, &a)?;
1584 for i in 0..2 {
1585 for j in 0..2 {
1586 assert_relative_eq!(
1587 anticomm.get(&[i, j])?,
1588 anticomm_ba.get(&[i, j])?,
1589 epsilon = 1e-6
1590 );
1591 }
1592 }
1593
1594 let anticomm_aa = anticommutator(&a, &a)?;
1596 let a_squared = a.matmul(&a)?;
1597 for i in 0..2 {
1598 for j in 0..2 {
1599 assert_relative_eq!(
1600 anticomm_aa.get(&[i, j])?,
1601 2.0 * a_squared.get(&[i, j])?,
1602 epsilon = 1e-6
1603 );
1604 }
1605 }
1606
1607 Ok(())
1608 }
1609}
1610
/// Convenience prelude re-exporting the crate's most commonly used items.
///
/// Downstream code can bring everything into scope with
/// `use <this_crate>::prelude::*;`.
pub mod prelude {
    // Numerical-stability items are re-exported by name rather than globbed,
    // mirroring the selective re-export at the crate root.
    pub use crate::numerical_stability::{
        check_numerical_stability, equilibrate_matrix, unequilibrate_solution,
        EquilibrationStrategy, ScalingFactors, StabilityConfig,
    };
    // Glob re-exports of the remaining core modules.
    pub use crate::{
        advanced_ops::*, comparison::*, decomposition::*, matrix_functions::*, randomized::*,
        solvers::*, sparse::*, special_matrices::*, taylor::*, utils::*,
    };
}