scirs2_linalg/matrix_functions/utils.rs

1//! Utility functions shared across matrix function modules
2
3use scirs2_core::ndarray::{Array2, ArrayView2};
4use scirs2_core::numeric::{Float, NumAssign, One};
5use std::iter::Sum;
6
7use crate::error::{LinalgError, LinalgResult};
8
9/// Check if a floating point number is close to an integer
10pub fn is_integer<F: Float>(x: F) -> bool {
11    (x - x.round()).abs() < F::from(1e-10).unwrap_or(F::epsilon())
12}
13
14/// Check if a matrix is diagonal
15pub fn is_diagonal<F>(a: &ArrayView2<F>) -> bool
16where
17    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
18{
19    let n = a.nrows();
20    for i in 0..n {
21        for j in 0..n {
22            if i != j && a[[i, j]].abs() > F::epsilon() {
23                return false;
24            }
25        }
26    }
27    true
28}
29
30/// Check if a matrix is symmetric
31pub fn is_symmetric<F>(a: &ArrayView2<F>) -> bool
32where
33    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
34{
35    let n = a.nrows();
36    if n != a.ncols() {
37        return false;
38    }
39
40    for i in 0..n {
41        for j in 0..n {
42            if (a[[i, j]] - a[[j, i]]).abs() > F::epsilon() {
43                return false;
44            }
45        }
46    }
47    true
48}
49
50/// Check if a matrix is the zero matrix
51pub fn is_zero_matrix<F>(a: &ArrayView2<F>) -> bool
52where
53    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
54{
55    let (m, n) = a.dim();
56    for i in 0..m {
57        for j in 0..n {
58            if a[[i, j]].abs() > F::epsilon() {
59                return false;
60            }
61        }
62    }
63    true
64}
65
66/// Check if a matrix is the identity matrix
67pub fn is_identity<F>(a: &ArrayView2<F>) -> bool
68where
69    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
70{
71    let n = a.nrows();
72    if n != a.ncols() {
73        return false;
74    }
75
76    for i in 0..n {
77        for j in 0..n {
78            let expected = if i == j { F::one() } else { F::zero() };
79            if (a[[i, j]] - expected).abs() > F::epsilon() {
80                return false;
81            }
82        }
83    }
84    true
85}
86
87/// Compute matrix multiplication C = A * B
88pub fn matrix_multiply<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
89where
90    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
91{
92    let (m, k1) = a.dim();
93    let (k2, n) = b.dim();
94
95    if k1 != k2 {
96        return Err(LinalgError::ShapeError(format!(
97            "Matrix dimensions incompatible for multiplication: ({}, {}) × ({}, {})",
98            m, k1, k2, n
99        )));
100    }
101
102    let mut c = Array2::<F>::zeros((m, n));
103    for i in 0..m {
104        for j in 0..n {
105            for k in 0..k1 {
106                c[[i, j]] += a[[i, k]] * b[[k, j]];
107            }
108        }
109    }
110
111    Ok(c)
112}
113
114/// Compute matrix power for small integer powers using repeated squaring
115pub fn integer_matrix_power<F>(a: &ArrayView2<F>, p: i32) -> LinalgResult<Array2<F>>
116where
117    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
118{
119    use crate::solve::solve_multiple;
120
121    let n = a.nrows();
122
123    if p == 0 {
124        return Ok(Array2::eye(n));
125    }
126
127    if p == 1 {
128        return Ok(a.to_owned());
129    }
130
131    if p < 0 {
132        // Compute A^{-|p|} = (A^{-1})^{|p|}
133        let a_inv = solve_multiple(a, &Array2::eye(n).view(), None)?;
134        return integer_matrix_power(&a_inv.view(), -p);
135    }
136
137    // Use repeated squaring for positive powers
138    let mut result = Array2::eye(n);
139    let mut base = a.to_owned();
140    let mut exp = p as u32;
141
142    while exp > 0 {
143        if exp % 2 == 1 {
144            result = matrix_multiply(&result.view(), &base.view())?;
145        }
146        base = matrix_multiply(&base.view(), &base.view())?;
147        exp /= 2;
148    }
149
150    Ok(result)
151}
152
153/// Compute the Frobenius norm of a matrix
154pub fn frobenius_norm<F>(a: &ArrayView2<F>) -> F
155where
156    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
157{
158    let (m, n) = a.dim();
159    let mut sum = F::zero();
160
161    for i in 0..m {
162        for j in 0..n {
163            sum += a[[i, j]] * a[[i, j]];
164        }
165    }
166
167    sum.sqrt()
168}
169
170/// Compute the maximum absolute difference between two matrices
171pub fn matrix_diff_norm<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<F>
172where
173    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
174{
175    if a.dim() != b.dim() {
176        return Err(LinalgError::ShapeError(
177            "Matrices must have the same dimensions".to_string(),
178        ));
179    }
180
181    let (m, n) = a.dim();
182    let mut max_diff = F::zero();
183
184    for i in 0..m {
185        for j in 0..n {
186            let diff = (a[[i, j]] - b[[i, j]]).abs();
187            if diff > max_diff {
188                max_diff = diff;
189            }
190        }
191    }
192
193    Ok(max_diff)
194}
195
196/// Scale a matrix by a scalar: result = alpha * A
197pub fn scale_matrix<F>(a: &ArrayView2<F>, alpha: F) -> Array2<F>
198where
199    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
200{
201    let (m, n) = a.dim();
202    let mut result = Array2::<F>::zeros((m, n));
203
204    for i in 0..m {
205        for j in 0..n {
206            result[[i, j]] = alpha * a[[i, j]];
207        }
208    }
209
210    result
211}
212
213/// Add two matrices: result = A + B
214pub fn matrix_add<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
215where
216    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
217{
218    if a.dim() != b.dim() {
219        return Err(LinalgError::ShapeError(
220            "Matrices must have the same dimensions for addition".to_string(),
221        ));
222    }
223
224    let (m, n) = a.dim();
225    let mut result = Array2::<F>::zeros((m, n));
226
227    for i in 0..m {
228        for j in 0..n {
229            result[[i, j]] = a[[i, j]] + b[[i, j]];
230        }
231    }
232
233    Ok(result)
234}
235
236/// Subtract two matrices: result = A - B
237pub fn matrix_subtract<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
238where
239    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
240{
241    if a.dim() != b.dim() {
242        return Err(LinalgError::ShapeError(
243            "Matrices must have the same dimensions for subtraction".to_string(),
244        ));
245    }
246
247    let (m, n) = a.dim();
248    let mut result = Array2::<F>::zeros((m, n));
249
250    for i in 0..m {
251        for j in 0..n {
252            result[[i, j]] = a[[i, j]] - b[[i, j]];
253        }
254    }
255
256    Ok(result)
257}
258
259/// Compute the matrix transpose
260pub fn matrix_transpose<F>(a: &ArrayView2<F>) -> Array2<F>
261where
262    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
263{
264    let (m, n) = a.dim();
265    let mut result = Array2::<F>::zeros((n, m));
266
267    for i in 0..m {
268        for j in 0..n {
269            result[[j, i]] = a[[i, j]];
270        }
271    }
272
273    result
274}
275
276/// Check if all eigenvalues are positive (for positive definite matrices)
277pub fn check_positive_definite<F>(eigenvals: &[F]) -> bool
278where
279    F: Float,
280{
281    eigenvals.iter().all(|&val| val > F::zero())
282}
283
284/// Check if all eigenvalues are non-negative (for positive semidefinite matrices)
285pub fn check_positive_semidefinite<F>(eigenvals: &[F]) -> bool
286where
287    F: Float,
288{
289    eigenvals.iter().all(|&val| val >= F::zero())
290}