scirs2_sparse/sym_ops.rs
// Optimized operations for symmetric sparse matrices
//
// This module provides specialized, optimized implementations of common
// operations for symmetric sparse matrices, including matrix-vector products
// and other computations that can take advantage of symmetry.

use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::ops::{Add, Mul};

use crate::error::SparseResult;
use crate::sym_coo::SymCooMatrix;
use crate::sym_csr::SymCsrMatrix;

// Import parallel operations from scirs2-core
use scirs2_core::parallel_ops::*;

/// Computes a matrix-vector product for symmetric CSR matrices.
///
/// This function computes `y = A * x` where `A` is a symmetric matrix
/// in CSR format, taking advantage of the symmetry. Only the lower (or upper)
/// triangular part of the matrix is stored, but the full matrix is used
/// in the computation.
///
/// # Arguments
///
/// * `matrix` - The symmetric matrix in CSR format
/// * `x` - The input vector
///
/// # Returns
///
/// The result vector `y = A * x`
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_sparse::sym_csr::SymCsrMatrix;
/// use scirs2_sparse::sym_ops::sym_csr_matvec;
///
/// // Create a symmetric matrix
/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
/// let indices = vec![0, 0, 1, 1, 2];
/// let indptr = vec![0, 1, 3, 5];
/// let matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
///
/// // Create a vector
/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
///
/// // Compute the product
/// let y = sym_csr_matvec(&matrix, &x.view()).unwrap();
///
/// // Verify the result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
/// assert_eq!(y[0], 4.0);
/// assert_eq!(y[1], 14.0);
/// assert_eq!(y[2], 9.0);
/// ```
#[allow(dead_code)]
pub fn sym_csr_matvec<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + Debug + Copy + Add<Output = T> + Send + Sync,
{
    let (n, _) = matrix.shape();
    if x.len() != n {
        return Err(crate::error::SparseError::DimensionMismatch {
            expected: n,
            found: x.len(),
        });
    }

    let nnz = matrix.nnz();

    // Use parallel implementation for larger matrices
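    // (the 1000-entry threshold is a heuristic; below it the thread coordination
    // overhead is assumed to outweigh any parallel speedup)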
    if nnz >= 1000 {
        sym_csr_matvec_parallel(matrix, x)
    } else {
        sym_csr_matvec_scalar(matrix, x)
    }
}

/// Parallel symmetric CSR matrix-vector multiplication
#[allow(dead_code)]
fn sym_csr_matvec_parallel<T>(
    matrix: &SymCsrMatrix<T>,
    x: &ArrayView1<T>,
) -> SparseResult<Array1<T>>
where
    T: Float + Debug + Copy + Add<Output = T> + Send + Sync,
{
    let (n, _) = matrix.shape();
    let mut y = Array1::zeros(n);

    // Determine the chunk size from the matrix size and thread count (capped at 256 rows per chunk)
    let chunk_size = std::cmp::max(1, n / scirs2_core::parallel_ops::get_num_threads()).min(256);

    // Use scirs2-core parallel operations for better performance
    let chunks: Vec<_> = (0..n)
        .collect::<Vec<_>>()
        .chunks(chunk_size)
        .map(|chunk| chunk.to_vec())
        .collect();

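    // Each chunk accumulates into its own full-length buffer: the mirrored
    // update to index `col` below can land in any row, so several threads
    // writing into a shared `y` would conflict. The per-chunk buffers are
    // summed into `y` once all chunks have finished.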
    let results: Vec<_> = parallel_map(&chunks, |row_chunk| {
        let mut local_y = Array1::zeros(n);

        for &row_i in row_chunk {
            let row_start = matrix.indptr[row_i];
            let row_end = matrix.indptr[row_i + 1];

            // Compute the dot product for this row
            let mut sum = T::zero();
            for j in row_start..row_end {
                let col = matrix.indices[j];
                let val = matrix.data[j];

                sum = sum + val * x[col];

                // For symmetric matrices, also add the symmetric contribution
                // if we're below the diagonal
                if row_i != col {
                    local_y[col] = local_y[col] + val * x[row_i];
                }
            }
            local_y[row_i] = local_y[row_i] + sum;
        }
        local_y
    });

    // Combine results from all chunks (manual reduction since parallel_reduce not available)
    for local_y in results {
        for i in 0..n {
            y[i] = y[i] + local_y[i];
        }
    }

    Ok(y)
}

/// Scalar fallback version of symmetric CSR matrix-vector multiplication
#[allow(dead_code)]
fn sym_csr_matvec_scalar<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + Debug + Copy + Add<Output = T>,
{
    let (n, _) = matrix.shape();
    let mut y = Array1::zeros(n);

    // Standard scalar implementation
    for i in 0..n {
        for j in matrix.indptr[i]..matrix.indptr[i + 1] {
            let col = matrix.indices[j];
            let val = matrix.data[j];

            y[i] = y[i] + val * x[col];

            // If not on the diagonal, also update the upper triangular part
            if i != col {
                y[col] = y[col] + val * x[i];
            }
        }
    }

    Ok(y)
}

/// Computes a matrix-vector product for symmetric COO matrices.
///
/// This function computes `y = A * x` where `A` is a symmetric matrix
/// in COO format, taking advantage of the symmetry. Only the lower (or upper)
/// triangular part of the matrix is stored, but the full matrix is used
/// in the computation.
///
/// # Arguments
///
/// * `matrix` - The symmetric matrix in COO format
/// * `x` - The input vector
///
/// # Returns
///
/// The result vector `y = A * x`
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_sparse::sym_coo::SymCooMatrix;
/// use scirs2_sparse::sym_ops::sym_coo_matvec;
///
/// // Create a symmetric matrix
/// let rows = vec![0, 1, 1, 2, 2];
/// let cols = vec![0, 0, 1, 1, 2];
/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
/// let matrix = SymCooMatrix::new(data, rows, cols, (3, 3)).unwrap();
///
/// // Create a vector
/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
///
/// // Compute the product
/// let y = sym_coo_matvec(&matrix, &x.view()).unwrap();
///
/// // Verify the result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
/// assert_eq!(y[0], 4.0);
/// assert_eq!(y[1], 14.0);
/// assert_eq!(y[2], 9.0);
/// ```
#[allow(dead_code)]
pub fn sym_coo_matvec<T>(matrix: &SymCooMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + Debug + Copy + Add<Output = T>,
{
    let (n, _) = matrix.shape();
    if x.len() != n {
        return Err(crate::error::SparseError::DimensionMismatch {
            expected: n,
            found: x.len(),
        });
    }

    let mut y = Array1::zeros(n);

    // Process each non-zero element in the lower triangular part
    for i in 0..matrix.data.len() {
        let row = matrix.rows[i];
        let col = matrix.cols[i];
        let val = matrix.data[i];

        y[row] = y[row] + val * x[col];

        // If not on the diagonal, also update the upper triangular part
        if row != col {
            y[col] = y[col] + val * x[row];
        }
    }

    Ok(y)
}

/// Performs a symmetric rank-1 update of a symmetric CSR matrix.
///
/// This computes `A = A + alpha * x * x^T` where `A` is a symmetric matrix,
/// `alpha` is a scalar, and `x` is a vector.
///
/// # Arguments
///
/// * `matrix` - The symmetric matrix to update (will be modified in-place)
/// * `x` - The vector to use for the update
/// * `alpha` - The scalar multiplier
///
/// # Returns
///
/// Result with `()` on success
///
/// # Note
///
/// This operation preserves symmetry but may change the sparsity pattern of the matrix.
/// Currently only implemented for dense updates (all elements of x*x^T are considered).
/// For sparse updates, additional optimizations would be possible.
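///
/// # Example
///
/// A minimal sketch mirroring the unit tests in this module (it assumes
/// `SymCsrMatrix::get(i, j)` returns the value stored at that position):
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_sparse::sym_csr::SymCsrMatrix;
/// use scirs2_sparse::sym_ops::sym_csr_rank1_update;
///
/// // Lower triangle of a 3x3 symmetric matrix with A[0][0] = 2
/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
/// let indices = vec![0, 0, 1, 1, 2];
/// let indptr = vec![0, 1, 3, 5];
/// let mut matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
///
/// // A rank-1 update with x = [1, 0, 0] and alpha = 3 only touches A[0][0]:
/// // it becomes 2 + 3 * 1 * 1 = 5.
/// let x = Array1::from_vec(vec![1.0, 0.0, 0.0]);
/// sym_csr_rank1_update(&mut matrix, &x.view(), 3.0).unwrap();
/// assert_eq!(matrix.get(0, 0), 5.0);
/// ```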
#[allow(dead_code)]
pub fn sym_csr_rank1_update<T>(
    matrix: &mut SymCsrMatrix<T>,
    x: &ArrayView1<T>,
    alpha: T,
) -> SparseResult<()>
where
    T: Float + Debug + Copy + Add<Output = T> + Mul<Output = T> + std::ops::AddAssign,
{
    let (n, _) = matrix.shape();
    if x.len() != n {
        return Err(crate::error::SparseError::DimensionMismatch {
            expected: n,
            found: x.len(),
        });
    }

    // For now, the easiest approach is to:
    // 1. Convert to a dense matrix
    // 2. Perform the rank-1 update
    // 3. Convert back to symmetric CSR format
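    // Note: the densify-and-rebuild path below costs O(n^2) time and memory,
    // so it is only reasonable for small to moderately sized matrices.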

    // Convert to dense
    let mut dense = matrix.to_dense();

    // Perform rank-1 update
    for i in 0..n {
        for j in 0..=i {
            // Only update lower triangular (including diagonal)
            let update = alpha * x[i] * x[j];
            dense[i][j] += update;
        }
    }

    // Convert back to CSR format (preserving symmetry)
    let mut data = Vec::new();
    let mut indices = Vec::new();
    let mut indptr = vec![0];

    for (i, row) in dense.iter().enumerate().take(n) {
        for (j, &val) in row.iter().enumerate().take(i + 1) {
            // Only include lower triangular (including diagonal)
            if val != T::zero() {
                data.push(val);
                indices.push(j);
            }
        }
        indptr.push(data.len());
    }

    // Replace the matrix data
    matrix.data = data;
    matrix.indices = indices;
    matrix.indptr = indptr;

    Ok(())
}

/// Calculates the quadratic form `x^T * A * x` for a symmetric matrix `A`.
///
/// This computation takes advantage of symmetry for efficiency.
///
/// # Arguments
///
/// * `matrix` - The symmetric matrix
/// * `x` - The vector
///
/// # Returns
///
/// The scalar result of `x^T * A * x`
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_sparse::sym_csr::SymCsrMatrix;
/// use scirs2_sparse::sym_ops::sym_csr_quadratic_form;
///
/// // Create a symmetric matrix
/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
/// let indices = vec![0, 0, 1, 1, 2];
/// let indptr = vec![0, 1, 3, 5];
/// let matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
///
/// // Create a vector
/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
///
/// // Compute the quadratic form
/// let result = sym_csr_quadratic_form(&matrix, &x.view()).unwrap();
///
/// // Verify: [1,2,3] * [2,1,0; 1,2,3; 0,3,1] * [1;2;3] = [1,2,3] * [4,14,9] = 4 + 28 + 27 = 59
/// assert_eq!(result, 59.0);
/// ```
#[allow(dead_code)]
pub fn sym_csr_quadratic_form<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<T>
where
    T: Float + Debug + Copy + Add<Output = T> + Mul<Output = T> + Send + Sync,
{
    // First compute A * x
    let ax = sym_csr_matvec(matrix, x)?;

    // Then compute x^T * (A * x)
    let mut result = T::zero();
    for i in 0..ax.len() {
        result = result + x[i] * ax[i];
    }

    Ok(result)
}

/// Calculates the trace of a symmetric matrix.
///
/// The trace is the sum of the diagonal elements.
///
/// # Arguments
///
/// * `matrix` - The symmetric matrix
///
/// # Returns
///
/// The trace of the matrix
///
/// # Example
///
/// ```
/// use scirs2_sparse::sym_csr::SymCsrMatrix;
/// use scirs2_sparse::sym_ops::sym_csr_trace;
///
/// // Create a symmetric matrix
/// let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
/// let indices = vec![0, 0, 1, 1, 2];
/// let indptr = vec![0, 1, 3, 5];
/// let matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
///
/// // Compute the trace
/// let trace = sym_csr_trace(&matrix);
///
/// // Verify: 2 + 2 + 1 = 5
/// assert_eq!(trace, 5.0);
/// ```
#[allow(dead_code)]
pub fn sym_csr_trace<T>(matrix: &SymCsrMatrix<T>) -> T
where
    T: Float + Debug + Copy + Add<Output = T>,
{
    let (n, _) = matrix.shape();
    let mut trace = T::zero();

    // Sum the diagonal elements
    for i in 0..n {
        for j in matrix.indptr[i]..matrix.indptr[i + 1] {
            let col = matrix.indices[j];
            if col == i {
                trace = trace + matrix.data[j];
                break;
            }
        }
    }

    trace
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sym_coo::SymCooMatrix;
    use crate::sym_csr::SymCsrMatrix;
    use crate::AsLinearOperator; // For the test_compare_with_standard_matvec test
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    // Create a simple symmetric matrix for testing
    fn create_test_sym_csr() -> SymCsrMatrix<f64> {
        // Create a symmetric matrix:
        // [2 1 0]
        // [1 2 3]
        // [0 3 1]

        // Lower triangular part (which is stored):
        // [2 0 0]
        // [1 2 0]
        // [0 3 1]

        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let indices = vec![0, 0, 1, 1, 2];
        let indptr = vec![0, 1, 3, 5];

        SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap()
    }

    // Create a simple symmetric matrix in COO format for testing
    fn create_test_sym_coo() -> SymCooMatrix<f64> {
        // Create a symmetric matrix:
        // [2 1 0]
        // [1 2 3]
        // [0 3 1]

        // Lower triangular part (which is stored):
        // [2 0 0]
        // [1 2 0]
        // [0 3 1]

        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let rows = vec![0, 1, 1, 2, 2];
        let cols = vec![0, 0, 1, 1, 2];

        SymCooMatrix::new(data, rows, cols, (3, 3)).unwrap()
    }

    #[test]
    fn test_sym_csr_matvec() {
        let matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = sym_csr_matvec(&matrix, &x.view()).unwrap();

        // Expected result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 4.0);
        assert_relative_eq!(y[1], 14.0);
        assert_relative_eq!(y[2], 9.0);
    }

    #[test]
    fn test_sym_coo_matvec() {
        let matrix = create_test_sym_coo();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = sym_coo_matvec(&matrix, &x.view()).unwrap();

        // Expected result: [2*1 + 1*2 + 0*3, 1*1 + 2*2 + 3*3, 0*1 + 3*2 + 1*3] = [4, 14, 9]
        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 4.0);
        assert_relative_eq!(y[1], 14.0);
        assert_relative_eq!(y[2], 9.0);
    }

    #[test]
    fn test_sym_csr_rank1_update() {
        let mut matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let alpha = 3.0;

        // Original diagonal element at (0,0) is 2.0
        // After rank-1 update with [1,0,0] and alpha=3, it should be 2+3*1*1 = 5
        sym_csr_rank1_update(&mut matrix, &x.view(), alpha).unwrap();

        // Check the updated value
        assert_relative_eq!(matrix.get(0, 0), 5.0);

        // Other values should remain unchanged
        assert_relative_eq!(matrix.get(0, 1), 1.0);
        assert_relative_eq!(matrix.get(1, 1), 2.0);
        assert_relative_eq!(matrix.get(1, 2), 3.0);
        assert_relative_eq!(matrix.get(2, 2), 1.0);
    }

    #[test]
    fn test_sym_csr_quadratic_form() {
        let matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = sym_csr_quadratic_form(&matrix, &x.view()).unwrap();

        // Expected result: [1,2,3] * [2,1,0; 1,2,3; 0,3,1] * [1;2;3]
        // = [1,2,3] * [4,14,9] = 1*4 + 2*14 + 3*9 = 4 + 28 + 27 = 59
        assert_relative_eq!(result, 59.0);
    }

    #[test]
    fn test_sym_csr_trace() {
        let matrix = create_test_sym_csr();

        let trace = sym_csr_trace(&matrix);

        // Expected: 2 + 2 + 1 = 5
        assert_relative_eq!(trace, 5.0);
    }

    #[test]
    fn test_compare_with_standard_matvec() {
        // Create matrices and vectors
        let sym_csr = create_test_sym_csr();
        let full_csr = sym_csr.to_csr().unwrap();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Compute using the optimized function
        let y_optimized = sym_csr_matvec(&sym_csr, &x.view()).unwrap();

        // Compute using the standard function
        let linear_op = full_csr.as_linear_operator();
        let y_standard = linear_op.matvec(x.as_slice().unwrap()).unwrap();

        // Compare results
        for i in 0..y_optimized.len() {
            assert_relative_eq!(y_optimized[i], y_standard[i]);
        }
    }
}