scirs2_linalg/
basic.rs

1//! Basic matrix operations
2
3use crate::error::{LinalgError, LinalgResult};
4use scirs2_core::ndarray::{Array2, ArrayView2, ScalarOperand};
5use scirs2_core::numeric::{Float, NumAssign};
6use std::iter::Sum;
7
8/// Compute the determinant of a square matrix.
9///
10/// # Arguments
11///
12/// * `a` - Input square matrix
13/// * `workers` - Number of worker threads (None = use default)
14///
15/// # Returns
16///
17/// * Determinant of the matrix
18///
19/// # Examples
20///
21/// ```
22/// use scirs2_core::ndarray::array;
23/// use scirs2_linalg::det;
24///
25/// let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
26/// let d = det(&a.view(), None).unwrap();
27/// assert!((d - (-2.0)).abs() < 1e-10);
28/// ```
29#[allow(dead_code)]
30pub fn det<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<F>
31where
32    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
33{
34    use crate::parallel;
35
36    // Configure workers for parallel operations
37    parallel::configure_workers(workers);
38
39    if a.nrows() != a.ncols() {
40        let rows = a.nrows();
41        let cols = a.ncols();
42        return Err(LinalgError::ShapeError(format!(
43            "Determinant computation failed: Matrix must be square\nMatrix shape: {rows}×{cols}\nExpected: Square matrix (n×n)"
44        )));
45    }
46
47    // Simple implementation for 2x2 and 3x3 matrices
48    match a.nrows() {
49        0 => Ok(F::one()),
50        1 => Ok(a[[0, 0]]),
51        2 => Ok(a[[0, 0]] * a[[1, 1]] - a[[0, 1]] * a[[1, 0]]),
52        3 => {
53            let det = a[[0, 0]] * (a[[1, 1]] * a[[2, 2]] - a[[1, 2]] * a[[2, 1]])
54                - a[[0, 1]] * (a[[1, 0]] * a[[2, 2]] - a[[1, 2]] * a[[2, 0]])
55                + a[[0, 2]] * (a[[1, 0]] * a[[2, 1]] - a[[1, 1]] * a[[2, 0]]);
56            Ok(det)
57        }
58        _ => {
59            // For larger matrices, use LU decomposition
60            use crate::decomposition::lu;
61
62            match lu(a, workers) {
63                Ok((p, _l, u)) => {
64                    // Calculate the determinant as the product of diagonal elements of U
65                    let mut det_u = F::one();
66                    for i in 0..u.nrows() {
67                        det_u *= u[[i, i]];
68                    }
69
70                    // Count the number of row swaps in the permutation matrix
71                    let mut swap_count = 0;
72                    for i in 0..p.nrows() {
73                        for j in 0..i {
74                            if p[[i, j]] == F::one() {
75                                swap_count += 1;
76                            }
77                        }
78                    }
79
80                    // Determinant is (-1)^swaps * det(U)
81                    if swap_count % 2 == 0 {
82                        Ok(det_u)
83                    } else {
84                        Ok(-det_u)
85                    }
86                }
87                Err(LinalgError::SingularMatrixError(_)) => {
88                    // Singular matrix has determinant zero
89                    Ok(F::zero())
90                }
91                Err(e) => Err(e),
92            }
93        }
94    }
95}
96
97/// Compute the inverse of a square matrix.
98///
99/// # Arguments
100///
101/// * `a` - Input square matrix
102/// * `workers` - Number of worker threads (None = use default)
103///
104/// # Returns
105///
106/// * Inverse of the matrix
107///
108/// # Examples
109///
110/// ```
111/// use scirs2_core::ndarray::array;
112/// use scirs2_linalg::inv;
113///
114/// let a = array![[1.0_f64, 0.0], [0.0, 2.0]];
115/// let a_inv = inv(&a.view(), None).unwrap();
116/// assert!((a_inv[[0, 0]] - 1.0).abs() < 1e-10);
117/// assert!((a_inv[[1, 1]] - 0.5).abs() < 1e-10);
118/// ```
119#[allow(dead_code)]
120pub fn inv<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
121where
122    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
123{
124    use crate::parallel;
125
126    // Configure workers for parallel operations
127    parallel::configure_workers(workers);
128
129    if a.nrows() != a.ncols() {
130        let rows = a.nrows();
131        let cols = a.ncols();
132        return Err(LinalgError::ShapeError(format!(
133            "Matrix inverse computation failed: Matrix must be square\nMatrix shape: {rows}×{cols}\nExpected: Square matrix (n×n)"
134        )));
135    }
136
137    // Simple implementation for 2x2 matrices
138    if a.nrows() == 2 {
139        let det_val = det(a, workers)?;
140        if det_val.abs() < F::epsilon() {
141            // Calculate condition number estimate for 2x2 matrix
142            let norm_a = (a[[0, 0]] * a[[0, 0]]
143                + a[[0, 1]] * a[[0, 1]]
144                + a[[1, 0]] * a[[1, 0]]
145                + a[[1, 1]] * a[[1, 1]])
146            .sqrt();
147            let cond_estimate = if det_val.abs() > F::zero() {
148                Some((norm_a / det_val.abs()).to_f64().unwrap_or(1e16))
149            } else {
150                None
151            };
152
153            return Err(LinalgError::singularmatrix_with_suggestions(
154                "matrix inverse",
155                a.dim(),
156                cond_estimate,
157            ));
158        }
159
160        let inv_det = F::one() / det_val;
161        let mut result = Array2::zeros((2, 2));
162        result[[0, 0]] = a[[1, 1]] * inv_det;
163        result[[0, 1]] = -a[[0, 1]] * inv_det;
164        result[[1, 0]] = -a[[1, 0]] * inv_det;
165        result[[1, 1]] = a[[0, 0]] * inv_det;
166        return Ok(result);
167    }
168
169    // For larger matrices, use the solve_multiple function with an identity matrix
170    use crate::solve::solve_multiple;
171
172    let n = a.nrows();
173    let mut identity = Array2::zeros((n, n));
174    for i in 0..n {
175        identity[[i, i]] = F::one();
176    }
177
178    // Solve A * X = I to get X = A^(-1)
179    match solve_multiple(a, &identity.view(), workers) {
180        Err(LinalgError::SingularMatrixError(_)) => {
181            // Use enhanced error with regularization suggestions
182            Err(LinalgError::singularmatrix_with_suggestions(
183                "matrix inverse via solve",
184                a.dim(),
185                None, // Could compute condition number here for better diagnostics
186            ))
187        }
188        other => other,
189    }
190}
191
192/// Raise a square matrix to the given power.
193///
194/// # Arguments
195///
196/// * `a` - Input square matrix
197/// * `n` - Power (can be positive, negative, or zero)
198/// * `workers` - Number of worker threads (None = use default)
199///
200/// # Returns
201///
202/// * Matrix raised to the power n
203///
204/// # Examples
205///
206/// ```
207/// use scirs2_core::ndarray::array;
208/// use scirs2_linalg::matrix_power;
209///
210/// let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
211///
212/// // Identity matrix for n=0
213/// let a_0 = matrix_power(&a.view(), 0, None).unwrap();
214/// assert!((a_0[[0, 0]] - 1.0).abs() < 1e-10);
215/// assert!((a_0[[0, 1]] - 0.0).abs() < 1e-10);
216/// assert!((a_0[[1, 0]] - 0.0).abs() < 1e-10);
217/// assert!((a_0[[1, 1]] - 1.0).abs() < 1e-10);
218/// ```
219#[allow(dead_code)]
220pub fn matrix_power<F>(a: &ArrayView2<F>, n: i32, workers: Option<usize>) -> LinalgResult<Array2<F>>
221where
222    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
223{
224    use crate::parallel;
225
226    // Configure workers for parallel operations
227    parallel::configure_workers(workers);
228
229    if a.nrows() != a.ncols() {
230        let rows = a.nrows();
231        let cols = a.ncols();
232        return Err(LinalgError::ShapeError(format!(
233            "Matrix power computation failed: Matrix must be square\nMatrix shape: {rows}×{cols}\nExpected: Square matrix (n×n)"
234        )));
235    }
236
237    let dim = a.nrows();
238
239    // Handle special cases
240    if n == 0 {
241        // Return identity matrix
242        let mut result = Array2::zeros((dim, dim));
243        for i in 0..dim {
244            result[[i, i]] = F::one();
245        }
246        return Ok(result);
247    }
248
249    if n == 1 {
250        // Return copy of the matrix
251        return Ok(a.to_owned());
252    }
253
254    if n == -1 {
255        // Return inverse
256        return inv(a, workers);
257    }
258
259    if n.abs() > 1 {
260        // For higher powers, we would implement more efficient algorithms
261        // using matrix decompositions or binary exponentiation
262        // This is a placeholder that will be replaced with a proper implementation
263        return Err(LinalgError::NotImplementedError(
264            "Matrix power for |n| > 1 not yet implemented".to_string(),
265        ));
266    }
267
268    // This should never be reached
269    Err(LinalgError::ComputationError(
270        "Unexpected error in matrix power calculation".to_string(),
271    ))
272}
273
274/// Compute the trace of a square matrix.
275///
276/// The trace is the sum of the diagonal elements.
277///
278/// # Arguments
279///
280/// * `a` - A square matrix
281///
282/// # Returns
283///
284/// * Trace of the matrix
285///
286/// # Examples
287///
288/// ```
289/// use scirs2_core::ndarray::array;
290/// use scirs2_linalg::basic_trace;
291///
292/// let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
293/// let tr = basic_trace(&a.view()).unwrap();
294/// assert!((tr - 5.0).abs() < 1e-10);
295/// ```
296#[allow(dead_code)]
297pub fn trace<F>(a: &ArrayView2<F>) -> LinalgResult<F>
298where
299    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
300{
301    if a.nrows() != a.ncols() {
302        let rows = a.nrows();
303        let cols = a.ncols();
304        return Err(LinalgError::ShapeError(format!(
305            "Matrix trace computation failed: Matrix must be square\nMatrix shape: {rows}×{cols}\nExpected: Square matrix (n×n)"
306        )));
307    }
308
309    let mut tr = F::zero();
310    for i in 0..a.nrows() {
311        tr += a[[i, i]];
312    }
313
314    Ok(tr)
315}
316
317//
318// Backward compatibility wrapper functions
319//
320
321/// Compute the determinant of a square matrix (backward compatibility wrapper).
322///
323/// This is a convenience function that calls `det` with `workers = None`.
324/// For new code, prefer using `det` directly with explicit workers parameter.
325#[allow(dead_code)]
326pub fn det_default<F>(a: &ArrayView2<F>) -> LinalgResult<F>
327where
328    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
329{
330    det(a, None)
331}
332
333/// Compute the inverse of a square matrix (backward compatibility wrapper).
334///
335/// This is a convenience function that calls `inv` with `workers = None`.
336/// For new code, prefer using `inv` directly with explicit workers parameter.
337#[allow(dead_code)]
338pub fn inv_default<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
339where
340    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
341{
342    inv(a, None)
343}
344
345/// Raise a square matrix to the given power (backward compatibility wrapper).
346///
347/// This is a convenience function that calls `matrix_power` with `workers = None`.
348/// For new code, prefer using `matrix_power` directly with explicit workers parameter.
349#[allow(dead_code)]
350pub fn matrix_power_default<F>(a: &ArrayView2<F>, n: i32) -> LinalgResult<Array2<F>>
351where
352    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
353{
354    matrix_power(a, n, None)
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Assert that every entry of `m` matches the identity matrix within `eps`.
    fn assert_is_identity(m: &Array2<f64>, eps: f64) {
        for ((row, col), &value) in m.indexed_iter() {
            let expected = if row == col { 1.0 } else { 0.0 };
            assert_relative_eq!(value, expected, epsilon = eps);
        }
    }

    #[test]
    fn test_det_2x2() {
        let cases = [
            (array![[1.0, 2.0], [3.0, 4.0]], -2.0),
            (array![[2.0, 0.0], [0.0, 3.0]], 6.0),
        ];
        for (matrix, expected) in cases.iter() {
            let d = det(&matrix.view(), None).unwrap();
            assert!((d - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_det_3x3() {
        let cases = [
            // Rows are linearly dependent, so the determinant vanishes.
            (array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], 0.0),
            (array![[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]], 24.0),
        ];
        for (matrix, expected) in cases.iter() {
            let d = det(&matrix.view(), None).unwrap();
            assert!((d - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_inv_2x2() {
        let cases = [
            (array![[1.0, 0.0], [0.0, 2.0]], array![[1.0, 0.0], [0.0, 0.5]]),
            (array![[1.0, 2.0], [3.0, 4.0]], array![[-2.0, 1.0], [1.5, -0.5]]),
        ];
        for (matrix, expected) in cases.iter() {
            let computed = inv(&matrix.view(), None).unwrap();
            for (&got, &want) in computed.iter().zip(expected.iter()) {
                assert_relative_eq!(got, want);
            }
        }
    }

    #[test]
    fn test_inv_large() {
        // 3x3: multiplying by the computed inverse must reproduce the identity.
        let a = array![[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]];
        let a_inv = inv(&a.view(), None).unwrap();
        assert_is_identity(&a.dot(&a_inv), 1e-10);

        // 4x4 diagonal: each inverse diagonal entry is the reciprocal.
        let b = array![
            [2.0, 0.0, 0.0, 0.0],
            [0.0, 3.0, 0.0, 0.0],
            [0.0, 0.0, 4.0, 0.0],
            [0.0, 0.0, 0.0, 5.0]
        ];
        let b_inv = inv(&b.view(), None).unwrap();
        for (k, &d) in [2.0, 3.0, 4.0, 5.0].iter().enumerate() {
            assert_relative_eq!(b_inv[[k, k]], 1.0 / d, epsilon = 1e-10);
        }

        // Rank-one matrix has no inverse.
        let c = array![[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]];
        assert!(inv(&c.view(), None).is_err());
    }

    #[test]
    fn test_matrix_power() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];

        // Exponent 0 yields the identity.
        let a_0 = matrix_power(&a.view(), 0, None).unwrap();
        assert_is_identity(&a_0, f64::EPSILON);

        // Exponent 1 yields an unchanged copy.
        let a_1 = matrix_power(&a.view(), 1, None).unwrap();
        for (&got, &want) in a_1.iter().zip(a.iter()) {
            assert_relative_eq!(got, want);
        }
    }

    #[test]
    fn test_det_large() {
        let tridiagonal = array![
            [2.0, 1.0, 0.0, 0.0],
            [1.0, 2.0, 1.0, 0.0],
            [0.0, 1.0, 2.0, 1.0],
            [0.0, 0.0, 1.0, 2.0]
        ];
        let diagonal = array![
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 3.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 4.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 5.0]
        ];
        let rank_one = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 4.0, 6.0, 8.0],
            [3.0, 6.0, 9.0, 12.0],
            [4.0, 8.0, 12.0, 16.0]
        ];

        let cases = [(tridiagonal, 5.0), (diagonal, 120.0), (rank_one, 0.0)];
        for (matrix, expected) in cases.iter() {
            let d = det(&matrix.view(), None).unwrap();
            assert_relative_eq!(d, *expected, epsilon = 1e-10);
        }
    }
}