scirs2_sparse/sym_ops.rs
// Optimized operations for symmetric sparse matrices
//
// This module provides specialized, optimized implementations of common
// operations for symmetric sparse matrices, including matrix-vector products
// and other computations that can take advantage of symmetry.

use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, SparseElement};
use std::fmt::Debug;
use std::ops::{Add, Mul};

use crate::error::SparseResult;
use crate::sym_coo::SymCooMatrix;
use crate::sym_csr::SymCsrMatrix;

// Import parallel operations from scirs2-core
use scirs2_core::parallel_ops::*;

/// Computes a matrix-vector product for symmetric CSR matrices.
///
/// This function computes `y = A * x` where `A` is a symmetric matrix
/// in CSR format, taking advantage of the symmetry. Only one triangular part
/// of the matrix (lower or upper) is stored, but the product is computed as
/// if the full symmetric matrix were present.
///
/// # Arguments
///
/// * `matrix` - The symmetric matrix in CSR format
/// * `x` - The input vector
///
/// # Returns
///
/// The result vector `y = A * x`
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_sparse::sym_csr::SymCsrMatrix;
/// use scirs2_sparse::sym_ops::sym_csr_matvec;
///
/// // Create a symmetric matrix
/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
/// let indices = vec![0, 0, 1, 1, 2];
/// let indptr = vec![0, 1, 3, 5];
/// let matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
///
/// // Create a vector
/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
///
/// // Compute the product
/// let y = sym_csr_matvec(&matrix, &x.view()).unwrap();
///
/// // Verify the result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
/// assert_eq!(y[0], 4.0);
/// assert_eq!(y[1], 14.0);
/// assert_eq!(y[2], 9.0);
/// ```
#[allow(dead_code)]
pub fn sym_csr_matvec<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T> + Send + Sync,
{
    let (n, _) = matrix.shape();
    if x.len() != n {
        return Err(crate::error::SparseError::DimensionMismatch {
            expected: n,
            found: x.len(),
        });
    }

    let nnz = matrix.nnz();

    // Use parallel implementation for larger matrices
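    // The 1000-nnz cutoff is a heuristic: below it, the cost of spawning and
    // reducing parallel work tends to outweigh any speedup.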
    if nnz >= 1000 {
        sym_csr_matvec_parallel(matrix, x)
    } else {
        sym_csr_matvec_scalar(matrix, x)
    }
}

/// Parallel symmetric CSR matrix-vector multiplication
#[allow(dead_code)]
fn sym_csr_matvec_parallel<T>(
    matrix: &SymCsrMatrix<T>,
    x: &ArrayView1<T>,
) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T> + Send + Sync,
{
    let (n, _) = matrix.shape();
    let mut y = Array1::zeros(n);

    // Determine the chunk size from the row count and the number of available threads
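    // (at least one row per chunk, capped at 256 rows so that large matrices
    // still produce enough chunks for the threads to balance load)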
    let chunk_size = std::cmp::max(1, n / scirs2_core::parallel_ops::get_num_threads()).min(256);

    // Use scirs2-core parallel operations for better performance
    let chunks: Vec<_> = (0..n)
        .collect::<Vec<_>>()
        .chunks(chunk_size)
        .map(|chunk| chunk.to_vec())
        .collect();

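    // Each chunk accumulates into its own dense buffer: the symmetric
    // contribution scatters into rows owned by other chunks, so writing
    // directly into a shared `y` would race. The per-chunk buffers are
    // summed into `y` afterwards.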
    let results: Vec<_> = parallel_map(&chunks, |row_chunk| {
        let mut local_y = Array1::zeros(n);

        for &row_i in row_chunk {
            let row_start = matrix.indptr[row_i];
            let row_end = matrix.indptr[row_i + 1];

            // Compute the dot product for this row
            let mut sum = T::sparse_zero();
            for j in row_start..row_end {
                let col = matrix.indices[j];
                let val = matrix.data[j];

                sum = sum + val * x[col];

                // For symmetric matrices, also add the symmetric contribution
                // if we're below the diagonal
                if row_i != col {
                    local_y[col] = local_y[col] + val * x[row_i];
                }
            }
            local_y[row_i] = local_y[row_i] + sum;
        }
        local_y
    });

    // Combine results from all chunks (manual reduction since parallel_reduce not available)
    for local_y in results {
        for i in 0..n {
            y[i] = y[i] + local_y[i];
        }
    }

    Ok(y)
}

/// Scalar fallback version of symmetric CSR matrix-vector multiplication
#[allow(dead_code)]
fn sym_csr_matvec_scalar<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T>,
{
    let (n, _) = matrix.shape();
    let mut y = Array1::zeros(n);

    // Standard scalar implementation
    for i in 0..n {
        for j in matrix.indptr[i]..matrix.indptr[i + 1] {
            let col = matrix.indices[j];
            let val = matrix.data[j];

            y[i] = y[i] + val * x[col];

            // If not on the diagonal, also update the upper triangular part
            if i != col {
                y[col] = y[col] + val * x[i];
            }
        }
    }

    Ok(y)
}

/// Computes a matrix-vector product for symmetric COO matrices.
///
/// This function computes `y = A * x` where `A` is a symmetric matrix
/// in COO format, taking advantage of the symmetry. Only one triangular part
/// of the matrix (lower or upper) is stored, but the product is computed as
/// if the full symmetric matrix were present.
///
/// # Arguments
///
/// * `matrix` - The symmetric matrix in COO format
/// * `x` - The input vector
///
/// # Returns
///
/// The result vector `y = A * x`
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_sparse::sym_coo::SymCooMatrix;
/// use scirs2_sparse::sym_ops::sym_coo_matvec;
///
/// // Create a symmetric matrix
/// let rows = vec![0, 1, 1, 2, 2];
/// let cols = vec![0, 0, 1, 1, 2];
/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
/// let matrix = SymCooMatrix::new(data, rows, cols, (3, 3)).unwrap();
///
/// // Create a vector
/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
///
/// // Compute the product
/// let y = sym_coo_matvec(&matrix, &x.view()).unwrap();
///
/// // Verify the result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
/// assert_eq!(y[0], 4.0);
/// assert_eq!(y[1], 14.0);
/// assert_eq!(y[2], 9.0);
/// ```
#[allow(dead_code)]
pub fn sym_coo_matvec<T>(matrix: &SymCooMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T>,
{
    let (n, _) = matrix.shape();
    if x.len() != n {
        return Err(crate::error::SparseError::DimensionMismatch {
            expected: n,
            found: x.len(),
        });
    }

    let mut y = Array1::zeros(n);

    // Process each non-zero element in the lower triangular part
    for i in 0..matrix.data.len() {
        let row = matrix.rows[i];
        let col = matrix.cols[i];
        let val = matrix.data[i];

        y[row] = y[row] + val * x[col];

        // If not on the diagonal, also update the upper triangular part
        if row != col {
            y[col] = y[col] + val * x[row];
        }
    }

    Ok(y)
}

/// Performs a symmetric rank-1 update of a symmetric CSR matrix.
///
/// This computes `A = A + alpha * x * x^T` where `A` is a symmetric matrix,
/// `alpha` is a scalar, and `x` is a vector.
///
/// # Arguments
///
/// * `matrix` - The symmetric matrix to update (will be modified in-place)
/// * `x` - The vector to use for the update
/// * `alpha` - The scalar multiplier
///
/// # Returns
///
/// Result with `()` on success
///
/// # Note
///
/// This operation preserves symmetry but may change the sparsity pattern of the matrix.
/// Currently only implemented for dense updates (all elements of x*x^T are considered).
/// For sparse updates, additional optimizations would be possible.
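///
/// # Example
///
/// A minimal example (mirroring `test_sym_csr_rank1_update` in this module's
/// tests): a rank-1 update with `x = [1, 0, 0]` and `alpha = 3.0` only
/// changes the `(0, 0)` entry.
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_sparse::sym_csr::SymCsrMatrix;
/// use scirs2_sparse::sym_ops::sym_csr_rank1_update;
///
/// // Lower triangle of the symmetric matrix [2 1 0; 1 2 3; 0 3 1]
/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
/// let indices = vec![0, 0, 1, 1, 2];
/// let indptr = vec![0, 1, 3, 5];
/// let mut matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
///
/// let x = Array1::from_vec(vec![1.0, 0.0, 0.0]);
/// sym_csr_rank1_update(&mut matrix, &x.view(), 3.0).unwrap();
///
/// // A[0, 0] = 2 + 3 * 1 * 1 = 5; the other entries are unchanged
/// assert_eq!(matrix.get(0, 0), 5.0);
/// assert_eq!(matrix.get(1, 2), 3.0);
/// ```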
#[allow(dead_code)]
pub fn sym_csr_rank1_update<T>(
    matrix: &mut SymCsrMatrix<T>,
    x: &ArrayView1<T>,
    alpha: T,
) -> SparseResult<()>
where
    T: Float
        + SparseElement
        + Debug
        + Copy
        + Add<Output = T>
        + Mul<Output = T>
        + std::ops::AddAssign,
{
    let (n, _) = matrix.shape();
    if x.len() != n {
        return Err(crate::error::SparseError::DimensionMismatch {
            expected: n,
            found: x.len(),
        });
    }

    // For now, the easiest approach is to:
    // 1. Convert to a dense matrix
    // 2. Perform the rank-1 update
    // 3. Convert back to symmetric CSR format
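    // (The dense round-trip costs O(n^2) time and memory; a sparse-aware
    // update would avoid this, as noted in the doc comment above.)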

    // Convert to dense
    let mut dense = matrix.to_dense();

    // Perform rank-1 update
    for i in 0..n {
        for j in 0..=i {
            // Only update lower triangular (including diagonal)
            let update = alpha * x[i] * x[j];
            dense[i][j] += update;
        }
    }

    // Convert back to CSR format (preserving symmetry)
    let mut data = Vec::new();
    let mut indices = Vec::new();
    let mut indptr = vec![0];

    for (i, row) in dense.iter().enumerate().take(n) {
        for (j, &val) in row.iter().enumerate().take(i + 1) {
            // Only include lower triangular (including diagonal)
            if val != T::sparse_zero() {
                data.push(val);
                indices.push(j);
            }
        }
        indptr.push(data.len());
    }

    // Replace the matrix data
    matrix.data = data;
    matrix.indices = indices;
    matrix.indptr = indptr;

    Ok(())
}

/// Calculates the quadratic form `x^T * A * x` for a symmetric matrix `A`.
///
/// This computation takes advantage of symmetry for efficiency.
///
/// # Arguments
///
/// * `matrix` - The symmetric matrix
/// * `x` - The vector
///
/// # Returns
///
/// The scalar result of `x^T * A * x`
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_sparse::sym_csr::SymCsrMatrix;
/// use scirs2_sparse::sym_ops::sym_csr_quadratic_form;
///
/// // Create a symmetric matrix
/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
/// let indices = vec![0, 0, 1, 1, 2];
/// let indptr = vec![0, 1, 3, 5];
/// let matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
///
/// // Create a vector
/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
///
/// // Compute the quadratic form
/// let result = sym_csr_quadratic_form(&matrix, &x.view()).unwrap();
///
/// // Verify: [1,2,3] * [2,1,0; 1,2,3; 0,3,1] * [1;2;3] = [1,2,3] * [4,14,9] = 4 + 28 + 27 = 59
/// assert_eq!(result, 59.0);
/// ```
#[allow(dead_code)]
pub fn sym_csr_quadratic_form<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<T>
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T> + Mul<Output = T> + Send + Sync,
{
    // First compute A * x
    let ax = sym_csr_matvec(matrix, x)?;

    // Then compute x^T * (A * x)
    let mut result = T::sparse_zero();
    for i in 0..ax.len() {
        result = result + x[i] * ax[i];
    }

    Ok(result)
}

/// Calculates the trace of a symmetric matrix.
///
/// The trace is the sum of the diagonal elements.
///
/// # Arguments
///
/// * `matrix` - The symmetric matrix
///
/// # Returns
///
/// The trace of the matrix
///
/// # Example
///
/// ```
/// use scirs2_sparse::sym_csr::SymCsrMatrix;
/// use scirs2_sparse::sym_ops::sym_csr_trace;
///
/// // Create a symmetric matrix
/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
/// let indices = vec![0, 0, 1, 1, 2];
/// let indptr = vec![0, 1, 3, 5];
/// let matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
///
/// // Compute the trace
/// let trace = sym_csr_trace(&matrix);
///
/// // Verify: 2 + 2 + 1 = 5
/// assert_eq!(trace, 5.0);
/// ```
#[allow(dead_code)]
pub fn sym_csr_trace<T>(matrix: &SymCsrMatrix<T>) -> T
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T>,
{
    let (n, _) = matrix.shape();
    let mut trace = T::sparse_zero();

    // Sum the diagonal elements
    for i in 0..n {
        for j in matrix.indptr[i]..matrix.indptr[i + 1] {
            let col = matrix.indices[j];
            if col == i {
                trace = trace + matrix.data[j];
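                // At most one stored entry per row is on the diagonal,
                // so stop scanning this row once it has been found.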
                break;
            }
        }
    }

    trace
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sym_coo::SymCooMatrix;
    use crate::sym_csr::SymCsrMatrix;
    use crate::AsLinearOperator; // For the test_compare_with_standard_matvec test
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    // Create a simple symmetric matrix for testing
    fn create_test_sym_csr() -> SymCsrMatrix<f64> {
        // Create a symmetric matrix:
        // [2 1 0]
        // [1 2 3]
        // [0 3 1]

        // Lower triangular part (which is stored):
        // [2 0 0]
        // [1 2 0]
        // [0 3 1]

        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let indices = vec![0, 0, 1, 1, 2];
        let indptr = vec![0, 1, 3, 5];

        SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap()
    }

    // Create a simple symmetric matrix in COO format for testing
    fn create_test_sym_coo() -> SymCooMatrix<f64> {
        // Create a symmetric matrix:
        // [2 1 0]
        // [1 2 3]
        // [0 3 1]

        // Lower triangular part (which is stored):
        // [2 0 0]
        // [1 2 0]
        // [0 3 1]

        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let rows = vec![0, 1, 1, 2, 2];
        let cols = vec![0, 0, 1, 1, 2];

        SymCooMatrix::new(data, rows, cols, (3, 3)).unwrap()
    }

    #[test]
    fn test_sym_csr_matvec() {
        let matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = sym_csr_matvec(&matrix, &x.view()).unwrap();

        // Expected result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 4.0);
        assert_relative_eq!(y[1], 14.0);
        assert_relative_eq!(y[2], 9.0);
    }

    #[test]
    fn test_sym_coo_matvec() {
        let matrix = create_test_sym_coo();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = sym_coo_matvec(&matrix, &x.view()).unwrap();

        // Expected result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 4.0);
        assert_relative_eq!(y[1], 14.0);
        assert_relative_eq!(y[2], 9.0);
    }

    #[test]
    fn test_sym_csr_rank1_update() {
        let mut matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let alpha = 3.0;

        // Original diagonal element at (0,0) is 2.0
        // After rank-1 update with [1,0,0] and alpha=3, it should be 2+3*1*1 = 5
        sym_csr_rank1_update(&mut matrix, &x.view(), alpha).unwrap();

        // Check the updated value
        assert_relative_eq!(matrix.get(0, 0), 5.0);

        // Other values should remain unchanged
        assert_relative_eq!(matrix.get(0, 1), 1.0);
        assert_relative_eq!(matrix.get(1, 1), 2.0);
        assert_relative_eq!(matrix.get(1, 2), 3.0);
        assert_relative_eq!(matrix.get(2, 2), 1.0);
    }

    #[test]
    fn test_sym_csr_quadratic_form() {
        let matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = sym_csr_quadratic_form(&matrix, &x.view()).unwrap();

        // Expected result: [1,2,3] * [2,1,0; 1,2,3; 0,3,1] * [1;2;3]
        // = [1,2,3] * [4,14,9] = 1*4 + 2*14 + 3*9 = 4 + 28 + 27 = 59
        assert_relative_eq!(result, 59.0);
    }

    #[test]
    fn test_sym_csr_trace() {
        let matrix = create_test_sym_csr();

        let trace = sym_csr_trace(&matrix);

        // Expected: 2 + 2 + 1 = 5
        assert_relative_eq!(trace, 5.0);
    }

    #[test]
    fn test_compare_with_standard_matvec() {
        // Create matrices and vectors
        let sym_csr = create_test_sym_csr();
        let full_csr = sym_csr.to_csr().unwrap();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Compute using the optimized function
        let y_optimized = sym_csr_matvec(&sym_csr, &x.view()).unwrap();

        // Compute using the standard function
        let linear_op = full_csr.as_linear_operator();
        let y_standard = linear_op.matvec(x.as_slice().unwrap()).unwrap();

        // Compare results
        for i in 0..y_optimized.len() {
            assert_relative_eq!(y_optimized[i], y_standard[i]);
        }
    }
}
563}