scirs2_sparse/
utils.rs

//! Utility functions for sparse matrices
//!
//! Constructors for common CSR matrices (identity, diagonal, random) and
//! helpers for inspecting them (density, symmetry check, sparsity pattern).
4
5use crate::csr::CsrMatrix;
6use crate::error::{SparseError, SparseResult};
7use num_traits::Zero;
8
9/// Create an identity matrix in CSR format
10///
11/// # Arguments
12///
13/// * `n` - Size of the matrix (n x n)
14///
15/// # Returns
16///
17/// * Identity matrix in CSR format
18///
19/// # Example
20///
21/// ```
22/// use scirs2_sparse::utils::identity;
23///
24/// // Create a 3x3 identity matrix
25/// let eye = identity(3).unwrap();
26/// ```
27#[allow(dead_code)]
28pub fn identity(n: usize) -> SparseResult<CsrMatrix<f64>> {
29    if n == 0 {
30        return Err(SparseError::ValueError(
31            "Matrix size must be positive".to_string(),
32        ));
33    }
34
35    let mut data = Vec::with_capacity(n);
36    let mut row_indices = Vec::with_capacity(n);
37    let mut col_indices = Vec::with_capacity(n);
38
39    for i in 0..n {
40        data.push(1.0);
41        row_indices.push(i);
42        col_indices.push(i);
43    }
44
45    CsrMatrix::new(data, row_indices, col_indices, (n, n))
46}
47
48/// Create a diagonal matrix in CSR format
49///
50/// # Arguments
51///
52/// * `diag` - Vector of diagonal elements
53///
54/// # Returns
55///
56/// * Diagonal matrix in CSR format
57///
58/// # Example
59///
60/// ```
61/// use scirs2_sparse::utils::diag;
62///
63/// // Create a diagonal matrix with elements [1, 2, 3]
64/// let d = diag(&[1.0, 2.0, 3.0]).unwrap();
65/// ```
66#[allow(dead_code)]
67pub fn diag(diag: &[f64]) -> SparseResult<CsrMatrix<f64>> {
68    if diag.is_empty() {
69        return Err(SparseError::ValueError(
70            "Diagonal vector must not be empty".to_string(),
71        ));
72    }
73
74    let n = diag.len();
75    let mut data = Vec::with_capacity(n);
76    let mut row_indices = Vec::with_capacity(n);
77    let mut col_indices = Vec::with_capacity(n);
78
79    for (i, &val) in diag.iter().enumerate() {
80        if val != 0.0 {
81            data.push(val);
82            row_indices.push(i);
83            col_indices.push(i);
84        }
85    }
86
87    CsrMatrix::new(data, row_indices, col_indices, (n, n))
88}
89
/// Calculate the density of a sparse matrix
///
/// Density is the fraction of stored entries: `nnz / (rows * cols)`.
///
/// # Arguments
///
/// * `shape` - Matrix shape (rows, cols)
/// * `nnz` - Number of non-zero elements
///
/// # Returns
///
/// * Density (fraction of non-zero elements); `0.0` when either dimension is zero
#[allow(dead_code)]
pub fn density(shape: (usize, usize), nnz: usize) -> f64 {
    let (rows, cols) = shape;
    if rows == 0 || cols == 0 {
        return 0.0;
    }

    // Multiply in floating point. `rows * cols` in `usize` can overflow for very
    // large shapes, which panics in debug builds and silently wraps (giving a
    // wrong density) in release builds.
    nnz as f64 / (rows as f64 * cols as f64)
}
109
110/// Check if a sparse matrix is symmetric
111///
112/// # Arguments
113///
114/// * `matrix` - Sparse matrix to check
115///
116/// # Returns
117///
118/// * true if the matrix is symmetric, false otherwise
119#[allow(dead_code)]
120pub fn is_symmetric(matrix: &CsrMatrix<f64>) -> bool {
121    let (rows, cols) = matrix.shape();
122
123    // Must be square
124    if rows != cols {
125        return false;
126    }
127
128    // Check if A = A^T
129    let transposed = matrix.transpose();
130    let a_dense = matrix.to_dense();
131    let at_dense = transposed.to_dense();
132
133    for i in 0..rows {
134        for j in 0..cols {
135            if (a_dense[i][j] - at_dense[i][j]).abs() > 1e-10 {
136                return false;
137            }
138        }
139    }
140
141    true
142}
143
144/// Generate a random sparse matrix with given density
145///
146/// # Arguments
147///
148/// * `shape` - Matrix shape (rows, cols)
149/// * `density` - Desired density (0.0 to 1.0)
150///
151/// # Returns
152///
153/// * Random sparse matrix in CSR format
154#[allow(dead_code)]
155pub fn random(shape: (usize, usize), density: f64) -> SparseResult<CsrMatrix<f64>> {
156    if !(0.0..=1.0).contains(&density) {
157        return Err(SparseError::ValueError(format!(
158            "Density must be between 0 and 1, got {}",
159            density
160        )));
161    }
162
163    let (rows, cols) = shape;
164    if rows == 0 || cols == 0 {
165        return Err(SparseError::ValueError(
166            "Matrix dimensions must be positive".to_string(),
167        ));
168    }
169
170    // Calculate number of non-zero elements
171    let nnz = (rows * cols) as f64 * density;
172    let nnz = nnz.round() as usize;
173
174    if nnz == 0 {
175        // Return empty matrix
176        return CsrMatrix::new(Vec::new(), Vec::new(), Vec::new(), shape);
177    }
178
179    // Generate random non-zero elements
180    let mut data = Vec::with_capacity(nnz);
181    let mut row_indices = Vec::with_capacity(nnz);
182    let mut col_indices = Vec::with_capacity(nnz);
183
184    // Use a simple approach: randomly select nnz cells
185    // Note: this is not the most efficient approach for very sparse matrices
186    let mut used = vec![vec![false; cols]; rows];
187    let mut count = 0;
188
189    use rand::Rng;
190    let mut rng = rand::rng();
191
192    while count < nnz {
193        let i = rng.random_range(0..rows);
194        let j = rng.random_range(0..cols);
195
196        if !used[i][j] {
197            used[i][j] = true;
198            data.push(rng.random_range(-1.0..1.0));
199            row_indices.push(i);
200            col_indices.push(j);
201            count += 1;
202        }
203    }
204
205    CsrMatrix::new(data, row_indices, col_indices, shape)
206}
207
208/// Calculate the sparsity pattern of a matrix
209///
210/// # Arguments
211///
212/// * `matrix` - Sparse matrix
213///
214/// # Returns
215///
216/// * Vector of vectors representing the sparsity pattern (1 for non-zero, 0 for zero)
217#[allow(dead_code)]
218pub fn sparsity_pattern<T>(matrix: &CsrMatrix<T>) -> Vec<Vec<usize>>
219where
220    T: Clone + Copy + Zero + PartialEq,
221{
222    let (rows, cols) = matrix.shape();
223    let dense = matrix.to_dense();
224
225    let mut pattern = vec![vec![0; cols]; rows];
226
227    for i in 0..rows {
228        for j in 0..cols {
229            if dense[i][j] != T::zero() {
230                pattern[i][j] = 1;
231            }
232        }
233    }
234
235    pattern
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_identity() {
        let n = 3;
        let eye = identity(n).unwrap();

        assert_eq!(eye.shape(), (n, n));
        assert_eq!(eye.nnz(), n);

        let dense = eye.to_dense();
        for (i, row) in dense.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_eq!(value, expected);
            }
        }
    }

    #[test]
    fn test_identity_zero_size_is_error() {
        assert!(identity(0).is_err());
    }

    #[test]
    fn test_diag() {
        let diag_elements = [1.0, 2.0, 3.0];
        let d = diag(&diag_elements).unwrap();

        assert_eq!(d.shape(), (3, 3));
        assert_eq!(d.nnz(), 3);

        let dense = d.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { diag_elements[i] } else { 0.0 };
                assert_eq!(dense[i][j], expected);
            }
        }
    }

    #[test]
    fn test_diag_skips_zero_entries() {
        // Zero diagonal values are not stored, but the shape still follows
        // the input length.
        let d = diag(&[1.0, 0.0, 2.0]).unwrap();

        assert_eq!(d.shape(), (3, 3));
        assert_eq!(d.nnz(), 2);

        let dense = d.to_dense();
        assert_eq!(dense[0][0], 1.0);
        assert_eq!(dense[1][1], 0.0);
        assert_eq!(dense[2][2], 2.0);
    }

    #[test]
    fn test_diag_empty_is_error() {
        assert!(diag(&[]).is_err());
    }

    #[test]
    fn test_density() {
        // Matrix with 25% non-zero elements
        assert_relative_eq!(density((4, 4), 4), 0.25, epsilon = 1e-10);

        // Empty matrix
        assert_relative_eq!(density((10, 10), 0), 0.0, epsilon = 1e-10);

        // Full matrix
        assert_relative_eq!(density((5, 5), 25), 1.0, epsilon = 1e-10);

        // Degenerate shapes report zero density
        assert_relative_eq!(density((0, 5), 0), 0.0, epsilon = 1e-10);
        assert_relative_eq!(density((5, 0), 0), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_is_symmetric() {
        // Create a symmetric matrix
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 0, 2];
        let data = vec![1.0, 2.0, 2.0, 3.0, 0.0, 4.0]; // Note: explicitly setting a zero
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();

        // A symmetric matrix should have the same value at (i,j) and (j,i)
        assert!(is_symmetric(&matrix));

        // Create a non-symmetric matrix
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 0, 2];
        let data = vec![1.0, 2.0, 3.0, 3.0, 0.0, 4.0]; // Changed 2.0 to 3.0
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();

        assert!(!is_symmetric(&matrix));
    }

    #[test]
    fn test_is_symmetric_non_square() {
        let matrix = CsrMatrix::new(vec![1.0], vec![0], vec![0], (2, 3)).unwrap();
        assert!(!is_symmetric(&matrix));
    }

    #[test]
    fn test_random() {
        // 30% of a 10x10 matrix -> exactly 30 stored entries.
        let m = random((10, 10), 0.3).unwrap();
        assert_eq!(m.shape(), (10, 10));
        assert_eq!(m.nnz(), 30);
        for row in m.to_dense() {
            for value in row {
                assert!((-1.0..1.0).contains(&value));
            }
        }

        // Boundary densities
        let empty = random((4, 4), 0.0).unwrap();
        assert_eq!(empty.shape(), (4, 4));
        assert_eq!(empty.nnz(), 0);

        let full = random((3, 3), 1.0).unwrap();
        assert_eq!(full.nnz(), 9);

        // Invalid arguments
        assert!(random((3, 3), 1.5).is_err());
        assert!(random((3, 3), -0.1).is_err());
        assert!(random((3, 3), f64::NAN).is_err());
        assert!(random((0, 3), 0.5).is_err());
        assert!(random((3, 0), 0.5).is_err());
    }

    #[test]
    fn test_sparsity_pattern() {
        // Create a sparse matrix
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();

        // Calculate sparsity pattern
        let pattern = sparsity_pattern(&matrix);

        // Expected pattern:
        // [1 0 1]
        // [0 0 1]
        // [1 1 0]
        let expected = vec![vec![1, 0, 1], vec![0, 0, 1], vec![1, 1, 0]];

        assert_eq!(pattern, expected);
    }
}