scirs2_sparse/
sym_ops.rs

// Optimized operations for symmetric sparse matrices.
//
// This module provides specialized implementations of common operations on
// symmetric sparse matrices (matrix-vector products, quadratic forms, traces)
// that take advantage of symmetry.
6
7use scirs2_core::ndarray::{Array1, ArrayView1};
8use scirs2_core::numeric::Float;
9use std::fmt::Debug;
10use std::ops::{Add, Mul};
11
12use crate::error::SparseResult;
13use crate::sym_coo::SymCooMatrix;
14use crate::sym_csr::SymCsrMatrix;
15
16// Import parallel operations from scirs2-core
17use scirs2_core::parallel_ops::*;
18
19/// Computes a matrix-vector product for symmetric CSR matrices.
20///
21/// This function computes `y = A * x` where `A` is a symmetric matrix
22/// in CSR format, taking advantage of the symmetry. Only the lower (or upper)
23/// triangular part of the matrix is stored, but the full matrix is used
24/// in the computation.
25///
26/// # Arguments
27///
28/// * `matrix` - The symmetric matrix in CSR format
29/// * `x` - The input vector
30///
31/// # Returns
32///
33/// The result vector `y = A * x`
34///
35/// # Example
36///
37/// ```
38/// use scirs2_core::ndarray::Array1;
39/// use scirs2_sparse::sym_csr::SymCsrMatrix;
40/// use scirs2_sparse::sym_ops::sym_csr_matvec;
41///
42/// // Create a symmetric matrix
43/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
44/// let indices = vec![0, 0, 1, 1, 2];
45/// let indptr = vec![0, 1, 3, 5];
46/// let matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
47///
48/// // Create a vector
49/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
50///
51/// // Compute the product
52/// let y = sym_csr_matvec(&matrix, &x.view()).unwrap();
53///
54/// // Verify the result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
55/// assert_eq!(y[0], 4.0);
56/// assert_eq!(y[1], 14.0);
57/// assert_eq!(y[2], 9.0);
58/// ```
59#[allow(dead_code)]
60pub fn sym_csr_matvec<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
61where
62    T: Float + Debug + Copy + Add<Output = T> + Send + Sync,
63{
64    let (n, _) = matrix.shape();
65    if x.len() != n {
66        return Err(crate::error::SparseError::DimensionMismatch {
67            expected: n,
68            found: x.len(),
69        });
70    }
71
72    let nnz = matrix.nnz();
73
74    // Use parallel implementation for larger matrices
75    if nnz >= 1000 {
76        sym_csr_matvec_parallel(matrix, x)
77    } else {
78        sym_csr_matvec_scalar(matrix, x)
79    }
80}
81
82/// Parallel symmetric CSR matrix-vector multiplication
83#[allow(dead_code)]
84fn sym_csr_matvec_parallel<T>(
85    matrix: &SymCsrMatrix<T>,
86    x: &ArrayView1<T>,
87) -> SparseResult<Array1<T>>
88where
89    T: Float + Debug + Copy + Add<Output = T> + Send + Sync,
90{
91    let (n, _) = matrix.shape();
92    let mut y = Array1::zeros(n);
93
94    // Determine optimal chunk size based on matrix size
95    let chunk_size = std::cmp::max(1, n / scirs2_core::parallel_ops::get_num_threads()).min(256);
96
97    // Use scirs2-core parallel operations for better performance
98    let chunks: Vec<_> = (0..n)
99        .collect::<Vec<_>>()
100        .chunks(chunk_size)
101        .map(|chunk| chunk.to_vec())
102        .collect();
103
104    let results: Vec<_> = parallel_map(&chunks, |row_chunk| {
105        let mut local_y = Array1::zeros(n);
106
107        for &row_i in row_chunk {
108            let row_start = matrix.indptr[row_i];
109            let row_end = matrix.indptr[row_i + 1];
110
111            // Compute the dot product for this row
112            let mut sum = T::zero();
113            for j in row_start..row_end {
114                let col = matrix.indices[j];
115                let val = matrix.data[j];
116
117                sum = sum + val * x[col];
118
119                // For symmetric matrices, also add the symmetric contribution
120                // if we're below the diagonal
121                if row_i != col {
122                    local_y[col] = local_y[col] + val * x[row_i];
123                }
124            }
125            local_y[row_i] = local_y[row_i] + sum;
126        }
127        local_y
128    });
129
130    // Combine results from all chunks (manual reduction since parallel_reduce not available)
131    for local_y in results {
132        for i in 0..n {
133            y[i] = y[i] + local_y[i];
134        }
135    }
136
137    Ok(y)
138}
139
140/// Scalar fallback version of symmetric CSR matrix-vector multiplication
141#[allow(dead_code)]
142fn sym_csr_matvec_scalar<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
143where
144    T: Float + Debug + Copy + Add<Output = T>,
145{
146    let (n, _) = matrix.shape();
147    let mut y = Array1::zeros(n);
148
149    // Standard scalar implementation
150    for i in 0..n {
151        for j in matrix.indptr[i]..matrix.indptr[i + 1] {
152            let col = matrix.indices[j];
153            let val = matrix.data[j];
154
155            y[i] = y[i] + val * x[col];
156
157            // If not on the diagonal, also update the upper triangular part
158            if i != col {
159                y[col] = y[col] + val * x[i];
160            }
161        }
162    }
163
164    Ok(y)
165}
166
167/// Computes a matrix-vector product for symmetric COO matrices.
168///
169/// This function computes `y = A * x` where `A` is a symmetric matrix
170/// in COO format, taking advantage of the symmetry. Only the lower (or upper)
171/// triangular part of the matrix is stored, but the full matrix is used
172/// in the computation.
173///
174/// # Arguments
175///
176/// * `matrix` - The symmetric matrix in COO format
177/// * `x` - The input vector
178///
179/// # Returns
180///
181/// The result vector `y = A * x`
182///
183/// # Example
184///
185/// ```
186/// use scirs2_core::ndarray::Array1;
187/// use scirs2_sparse::sym_coo::SymCooMatrix;
188/// use scirs2_sparse::sym_ops::sym_coo_matvec;
189///
190/// // Create a symmetric matrix
191/// let rows = vec![0, 1, 1, 2, 2];
192/// let cols = vec![0, 0, 1, 1, 2];
193/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
194/// let matrix = SymCooMatrix::new(data, rows, cols, (3, 3)).unwrap();
195///
196/// // Create a vector
197/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
198///
199/// // Compute the product
200/// let y = sym_coo_matvec(&matrix, &x.view()).unwrap();
201///
202/// // Verify the result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
203/// assert_eq!(y[0], 4.0);
204/// assert_eq!(y[1], 14.0);
205/// assert_eq!(y[2], 9.0);
206/// ```
207#[allow(dead_code)]
208pub fn sym_coo_matvec<T>(matrix: &SymCooMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
209where
210    T: Float + Debug + Copy + Add<Output = T>,
211{
212    let (n, _) = matrix.shape();
213    if x.len() != n {
214        return Err(crate::error::SparseError::DimensionMismatch {
215            expected: n,
216            found: x.len(),
217        });
218    }
219
220    let mut y = Array1::zeros(n);
221
222    // Process each non-zero element in the lower triangular part
223    for i in 0..matrix.data.len() {
224        let row = matrix.rows[i];
225        let col = matrix.cols[i];
226        let val = matrix.data[i];
227
228        y[row] = y[row] + val * x[col];
229
230        // If not on the diagonal, also update the upper triangular part
231        if row != col {
232            y[col] = y[col] + val * x[row];
233        }
234    }
235
236    Ok(y)
237}
238
239/// Performs a symmetric rank-1 update of a symmetric CSR matrix.
240///
241/// This computes `A = A + alpha * x * x^T` where `A` is a symmetric matrix,
242/// `alpha` is a scalar, and `x` is a vector.
243///
244/// # Arguments
245///
246/// * `matrix` - The symmetric matrix to update (will be modified in-place)
247/// * `x` - The vector to use for the update
248/// * `alpha` - The scalar multiplier
249///
250/// # Returns
251///
252/// Result with `()` on success
253///
254/// # Note
255///
256/// This operation preserves symmetry but may change the sparsity pattern of the matrix.
257/// Currently only implemented for dense updates (all elements of x*x^T are considered).
258/// For sparse updates, additional optimizations would be possible.
259#[allow(dead_code)]
260pub fn sym_csr_rank1_update<T>(
261    matrix: &mut SymCsrMatrix<T>,
262    x: &ArrayView1<T>,
263    alpha: T,
264) -> SparseResult<()>
265where
266    T: Float + Debug + Copy + Add<Output = T> + Mul<Output = T> + std::ops::AddAssign,
267{
268    let (n, _) = matrix.shape();
269    if x.len() != n {
270        return Err(crate::error::SparseError::DimensionMismatch {
271            expected: n,
272            found: x.len(),
273        });
274    }
275
276    // For now, the easiest approach is to:
277    // 1. Convert to a dense matrix
278    // 2. Perform the rank-1 update
279    // 3. Convert back to symmetric CSR format
280
281    // Convert to dense
282    let mut dense = matrix.to_dense();
283
284    // Perform rank-1 update
285    for i in 0..n {
286        for j in 0..=i {
287            // Only update lower triangular (including diagonal)
288            let update = alpha * x[i] * x[j];
289            dense[i][j] += update;
290        }
291    }
292
293    // Convert back to CSR format (preserving symmetry)
294    let mut data = Vec::new();
295    let mut indices = Vec::new();
296    let mut indptr = vec![0];
297
298    for (i, row) in dense.iter().enumerate().take(n) {
299        for (j, &val) in row.iter().enumerate().take(i + 1) {
300            // Only include lower triangular (including diagonal)
301            if val != T::zero() {
302                data.push(val);
303                indices.push(j);
304            }
305        }
306        indptr.push(data.len());
307    }
308
309    // Replace the matrix data
310    matrix.data = data;
311    matrix.indices = indices;
312    matrix.indptr = indptr;
313
314    Ok(())
315}
316
317/// Calculates the quadratic form `x^T * A * x` for a symmetric matrix `A`.
318///
319/// This computation takes advantage of symmetry for efficiency.
320///
321/// # Arguments
322///
323/// * `matrix` - The symmetric matrix
324/// * `x` - The vector
325///
326/// # Returns
327///
328/// The scalar result of `x^T * A * x`
329///
330/// # Example
331///
332/// ```
333/// use scirs2_core::ndarray::Array1;
334/// use scirs2_sparse::sym_csr::SymCsrMatrix;
335/// use scirs2_sparse::sym_ops::sym_csr_quadratic_form;
336///
337/// // Create a symmetric matrix
338/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
339/// let indices = vec![0, 0, 1, 1, 2];
340/// let indptr = vec![0, 1, 3, 5];
341/// let matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
342///
343/// // Create a vector
344/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
345///
346/// // Compute the quadratic form
347/// let result = sym_csr_quadratic_form(&matrix, &x.view()).unwrap();
348///
349/// // Verify: [1,2,3] * [2,1,0; 1,2,3; 0,3,1] * [1;2;3] = [1,2,3] * [4,14,9] = 4 + 28 + 27 = 59
350/// assert_eq!(result, 59.0);
351/// ```
352#[allow(dead_code)]
353pub fn sym_csr_quadratic_form<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<T>
354where
355    T: Float + Debug + Copy + Add<Output = T> + Mul<Output = T> + Send + Sync,
356{
357    // First compute A * x
358    let ax = sym_csr_matvec(matrix, x)?;
359
360    // Then compute x^T * (A * x)
361    let mut result = T::zero();
362    for i in 0..ax.len() {
363        result = result + x[i] * ax[i];
364    }
365
366    Ok(result)
367}
368
369/// Calculates the trace of a symmetric matrix.
370///
371/// The trace is the sum of the diagonal elements.
372///
373/// # Arguments
374///
375/// * `matrix` - The symmetric matrix
376///
377/// # Returns
378///
379/// The trace of the matrix
380///
381/// # Example
382///
383/// ```
384/// use scirs2_sparse::sym_csr::SymCsrMatrix;
385/// use scirs2_sparse::sym_ops::sym_csr_trace;
386///
387/// // Create a symmetric matrix
388/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
389/// let indices = vec![0, 0, 1, 1, 2];
390/// let indptr = vec![0, 1, 3, 5];
391/// let matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
392///
393/// // Compute the trace
394/// let trace = sym_csr_trace(&matrix);
395///
396/// // Verify: 2 + 2 + 1 = 5
397/// assert_eq!(trace, 5.0);
398/// ```
399#[allow(dead_code)]
400pub fn sym_csr_trace<T>(matrix: &SymCsrMatrix<T>) -> T
401where
402    T: Float + Debug + Copy + Add<Output = T>,
403{
404    let (n, _) = matrix.shape();
405    let mut trace = T::zero();
406
407    // Sum the diagonal elements
408    for i in 0..n {
409        for j in matrix.indptr[i]..matrix.indptr[i + 1] {
410            let col = matrix.indices[j];
411            if col == i {
412                trace = trace + matrix.data[j];
413                break;
414            }
415        }
416    }
417
418    trace
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sym_coo::SymCooMatrix;
    use crate::sym_csr::SymCsrMatrix;
    use crate::AsLinearOperator; // For the test_compare_with_standard_matvec test
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    // Create a simple symmetric matrix for testing
    fn create_test_sym_csr() -> SymCsrMatrix<f64> {
        // Create a symmetric matrix:
        // [2 1 0]
        // [1 2 3]
        // [0 3 1]

        // Lower triangular part (which is stored):
        // [2 0 0]
        // [1 2 0]
        // [0 3 1]

        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let indices = vec![0, 0, 1, 1, 2];
        let indptr = vec![0, 1, 3, 5];

        SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap()
    }

    // Create a simple symmetric matrix in COO format for testing
    fn create_test_sym_coo() -> SymCooMatrix<f64> {
        // Create a symmetric matrix:
        // [2 1 0]
        // [1 2 3]
        // [0 3 1]

        // Lower triangular part (which is stored):
        // [2 0 0]
        // [1 2 0]
        // [0 3 1]

        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let rows = vec![0, 1, 1, 2, 2];
        let cols = vec![0, 0, 1, 1, 2];

        SymCooMatrix::new(data, rows, cols, (3, 3)).unwrap()
    }

    // Build the n x n symmetric tridiagonal matrix with 2 on the diagonal and -1 on both
    // off-diagonals, storing only the lower triangle (2n - 1 entries).
    fn create_tridiagonal_sym_csr(n: usize) -> SymCsrMatrix<f64> {
        let mut data = Vec::new();
        let mut indices = Vec::new();
        let mut indptr = vec![0];

        for i in 0..n {
            if i > 0 {
                data.push(-1.0);
                indices.push(i - 1);
            }
            data.push(2.0);
            indices.push(i);
            indptr.push(data.len());
        }

        SymCsrMatrix::new(data, indptr, indices, (n, n)).unwrap()
    }

    #[test]
    fn test_sym_csr_matvec() {
        let matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = sym_csr_matvec(&matrix, &x.view()).unwrap();

        // Expected result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 4.0);
        assert_relative_eq!(y[1], 14.0);
        assert_relative_eq!(y[2], 9.0);
    }

    #[test]
    fn test_sym_csr_matvec_parallel_path() {
        // 2 * 600 - 1 = 1199 stored entries, above the 1000-entry parallel threshold.
        let n = 600;
        let matrix = create_tridiagonal_sym_csr(n);
        assert!(matrix.nnz() >= 1000);

        let x = Array1::from_vec((1..=n).map(|i| i as f64).collect());
        let y = sym_csr_matvec(&matrix, &x.view()).unwrap();
        assert_eq!(y.len(), n);

        // Reference: y_i = 2 x_i - x_{i-1} - x_{i+1}, omitting out-of-range neighbours.
        for i in 0..n {
            let mut expected = 2.0 * x[i];
            if i > 0 {
                expected -= x[i - 1];
            }
            if i + 1 < n {
                expected -= x[i + 1];
            }
            assert_relative_eq!(y[i], expected);
        }

        // The parallel and sequential kernels must agree.
        let y_scalar = sym_csr_matvec_scalar(&matrix, &x.view()).unwrap();
        for i in 0..n {
            assert_relative_eq!(y[i], y_scalar[i]);
        }
    }

    #[test]
    fn test_sym_matvec_dimension_mismatch() {
        let x = Array1::from_vec(vec![1.0, 2.0]);

        assert!(sym_csr_matvec(&create_test_sym_csr(), &x.view()).is_err());
        assert!(sym_coo_matvec(&create_test_sym_coo(), &x.view()).is_err());
    }

    #[test]
    fn test_sym_coo_matvec() {
        let matrix = create_test_sym_coo();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = sym_coo_matvec(&matrix, &x.view()).unwrap();

        // Expected result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 4.0);
        assert_relative_eq!(y[1], 14.0);
        assert_relative_eq!(y[2], 9.0);
    }

    #[test]
    fn test_sym_csr_rank1_update() {
        let mut matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let alpha = 3.0;

        // Original diagonal element at (0,0) is 2.0
        // After rank-1 update with [1,0,0] and alpha=3, it should be 2+3*1*1 = 5
        sym_csr_rank1_update(&mut matrix, &x.view(), alpha).unwrap();

        // Check the updated value
        assert_relative_eq!(matrix.get(0, 0), 5.0);

        // Other values should remain unchanged
        assert_relative_eq!(matrix.get(0, 1), 1.0);
        assert_relative_eq!(matrix.get(1, 1), 2.0);
        assert_relative_eq!(matrix.get(1, 2), 3.0);
        assert_relative_eq!(matrix.get(2, 2), 1.0);
    }

    #[test]
    fn test_sym_csr_quadratic_form() {
        let matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = sym_csr_quadratic_form(&matrix, &x.view()).unwrap();

        // Expected result: [1,2,3] * [2,1,0; 1,2,3; 0,3,1] * [1;2;3]
        // = [1,2,3] * [4,14,9] = 1*4 + 2*14 + 3*9 = 4 + 28 + 27 = 59
        assert_relative_eq!(result, 59.0);
    }

    #[test]
    fn test_sym_csr_trace() {
        let matrix = create_test_sym_csr();

        let trace = sym_csr_trace(&matrix);

        // Expected: 2 + 2 + 1 = 5
        assert_relative_eq!(trace, 5.0);
    }

    #[test]
    fn test_compare_with_standard_matvec() {
        // Create matrices and vectors
        let sym_csr = create_test_sym_csr();
        let full_csr = sym_csr.to_csr().unwrap();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Compute using the optimized function
        let y_optimized = sym_csr_matvec(&sym_csr, &x.view()).unwrap();

        // Compute using the standard function
        let linear_op = full_csr.as_linear_operator();
        let y_standard = linear_op.matvec(x.as_slice().unwrap()).unwrap();

        // Compare results
        for i in 0..y_optimized.len() {
            assert_relative_eq!(y_optimized[i], y_standard[i]);
        }
    }
}