scirs2_sparse/linalg/
interface.rs

1//! Linear operator interface for sparse matrices
2
3use crate::error::{SparseError, SparseResult};
4use num_traits::{Float, NumAssign};
5use std::fmt::Debug;
6use std::iter::Sum;
7use std::marker::PhantomData;
8
9/// Trait for representing a linear operator
10///
11/// This trait provides an abstract interface for linear operators,
12/// allowing matrix-free implementations and compositions.
13pub trait LinearOperator<F: Float> {
14    /// The shape of the operator (rows, columns)
15    fn shape(&self) -> (usize, usize);
16
17    /// Apply the operator to a vector: y = A * x
18    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>>;
19
20    /// Apply the operator to a matrix: Y = A * X
21    /// where X is column-major (each column is a vector)
22    fn matmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
23        let mut result = Vec::new();
24        for col in x {
25            result.push(self.matvec(col)?);
26        }
27        Ok(result)
28    }
29
30    /// Apply the adjoint of the operator to a vector: y = A^H * x
31    /// Default implementation returns an error
32    fn rmatvec(&self, _x: &[F]) -> SparseResult<Vec<F>> {
33        Err(crate::error::SparseError::OperationNotSupported(
34            "adjoint not implemented for this operator".to_string(),
35        ))
36    }
37
38    /// Apply the adjoint of the operator to a matrix: Y = A^H * X
39    /// Default implementation calls rmatvec for each column
40    fn rmatmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
41        let mut result = Vec::new();
42        for col in x {
43            result.push(self.rmatvec(col)?);
44        }
45        Ok(result)
46    }
47
48    /// Check if the operator supports adjoint operations
49    fn has_adjoint(&self) -> bool {
50        false
51    }
52}
53
/// The identity map on vectors of a fixed length: `I * x = x`.
pub struct IdentityOperator<F> {
    /// Number of rows (and columns) of the operator.
    size: usize,
    /// Ties the scalar type `F` to the struct without storing a value.
    _phantom: PhantomData<F>,
}

impl<F> IdentityOperator<F> {
    /// Builds an identity operator acting on vectors of length `size`.
    pub fn new(size: usize) -> Self {
        let _phantom = PhantomData;
        IdentityOperator { size, _phantom }
    }
}
69
70impl<F: Float> LinearOperator<F> for IdentityOperator<F> {
71    fn shape(&self) -> (usize, usize) {
72        (self.size, self.size)
73    }
74
75    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
76        if x.len() != self.size {
77            return Err(crate::error::SparseError::DimensionMismatch {
78                expected: self.size,
79                found: x.len(),
80            });
81        }
82        Ok(x.to_vec())
83    }
84
85    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
86        self.matvec(x)
87    }
88
89    fn has_adjoint(&self) -> bool {
90        true
91    }
92}
93
94/// Scaled identity operator: (alpha * I) * x = alpha * x
95pub struct ScaledIdentityOperator<F> {
96    size: usize,
97    scale: F,
98}
99
100impl<F: Float> ScaledIdentityOperator<F> {
101    /// Create a new scaled identity operator
102    pub fn new(size: usize, scale: F) -> Self {
103        Self { size, scale }
104    }
105}
106
107impl<F: Float + NumAssign> LinearOperator<F> for ScaledIdentityOperator<F> {
108    fn shape(&self) -> (usize, usize) {
109        (self.size, self.size)
110    }
111
112    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
113        if x.len() != self.size {
114            return Err(crate::error::SparseError::DimensionMismatch {
115                expected: self.size,
116                found: x.len(),
117            });
118        }
119        Ok(x.iter().map(|&xi| xi * self.scale).collect())
120    }
121
122    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
123        // For real scalars, adjoint is the same
124        self.matvec(x)
125    }
126
127    fn has_adjoint(&self) -> bool {
128        true
129    }
130}
131
132/// Diagonal operator: D * x where D is a diagonal matrix
133pub struct DiagonalOperator<F> {
134    diagonal: Vec<F>,
135}
136
137impl<F: Float> DiagonalOperator<F> {
138    /// Create a new diagonal operator from diagonal values
139    pub fn new(diagonal: Vec<F>) -> Self {
140        Self { diagonal }
141    }
142
143    /// Get the diagonal values
144    pub fn diagonal(&self) -> &[F] {
145        &self.diagonal
146    }
147}
148
149impl<F: Float + NumAssign> LinearOperator<F> for DiagonalOperator<F> {
150    fn shape(&self) -> (usize, usize) {
151        let n = self.diagonal.len();
152        (n, n)
153    }
154
155    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
156        if x.len() != self.diagonal.len() {
157            return Err(crate::error::SparseError::DimensionMismatch {
158                expected: self.diagonal.len(),
159                found: x.len(),
160            });
161        }
162        Ok(x.iter()
163            .zip(&self.diagonal)
164            .map(|(&xi, &di)| xi * di)
165            .collect())
166    }
167
168    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
169        // For real diagonal matrices, adjoint is the same
170        self.matvec(x)
171    }
172
173    fn has_adjoint(&self) -> bool {
174        true
175    }
176}
177
/// The zero map of a given (possibly rectangular) shape: `0 * x = 0`.
pub struct ZeroOperator<F> {
    /// Operator dimensions as `(rows, columns)`.
    shape: (usize, usize),
    /// Ties the scalar type `F` to the struct without storing a value.
    _phantom: PhantomData<F>,
}

impl<F> ZeroOperator<F> {
    /// Builds a zero operator with `rows` rows and `cols` columns.
    #[allow(dead_code)]
    pub fn new(rows: usize, cols: usize) -> Self {
        let shape = (rows, cols);
        ZeroOperator {
            shape,
            _phantom: PhantomData,
        }
    }
}
194
195impl<F: Float> LinearOperator<F> for ZeroOperator<F> {
196    fn shape(&self) -> (usize, usize) {
197        self.shape
198    }
199
200    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
201        if x.len() != self.shape.1 {
202            return Err(crate::error::SparseError::DimensionMismatch {
203                expected: self.shape.1,
204                found: x.len(),
205            });
206        }
207        Ok(vec![F::zero(); self.shape.0])
208    }
209
210    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
211        if x.len() != self.shape.0 {
212            return Err(crate::error::SparseError::DimensionMismatch {
213                expected: self.shape.0,
214                found: x.len(),
215            });
216        }
217        Ok(vec![F::zero(); self.shape.1])
218    }
219
220    fn has_adjoint(&self) -> bool {
221        true
222    }
223}
224
225/// Convert a sparse matrix to a linear operator
226pub trait AsLinearOperator<F: Float> {
227    /// Convert to a linear operator
228    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>>;
229}
230
/// Wrapper that owns a matrix of type `M` so it can be used as a linear
/// operator over the scalar type `F`.
pub struct MatrixLinearOperator<F, M> {
    /// The wrapped matrix.
    matrix: M,
    /// Ties the scalar type `F` to the struct without storing a value.
    _phantom: PhantomData<F>,
}

impl<F, M> MatrixLinearOperator<F, M> {
    /// Wraps `matrix`, taking ownership of it.
    pub fn new(matrix: M) -> Self {
        let _phantom = PhantomData;
        MatrixLinearOperator { matrix, _phantom }
    }
}
246
247// Implementation of LinearOperator for CSR matrices
248use crate::csr::CsrMatrix;
249
250impl<F: Float + NumAssign + Sum + 'static + Debug> LinearOperator<F>
251    for MatrixLinearOperator<F, CsrMatrix<F>>
252{
253    fn shape(&self) -> (usize, usize) {
254        (self.matrix.rows(), self.matrix.cols())
255    }
256
257    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
258        if x.len() != self.matrix.cols() {
259            return Err(SparseError::DimensionMismatch {
260                expected: self.matrix.cols(),
261                found: x.len(),
262            });
263        }
264
265        // Manual implementation for generic types
266        let mut result = vec![F::zero(); self.matrix.rows()];
267        for (row, result_elem) in result.iter_mut().enumerate().take(self.matrix.rows()) {
268            let row_range = self.matrix.row_range(row);
269            let row_indices = &self.matrix.col_indices()[row_range.clone()];
270            let row_data = &self.matrix.data[row_range];
271
272            let mut sum = F::zero();
273            for (col_idx, &col) in row_indices.iter().enumerate() {
274                sum += row_data[col_idx] * x[col];
275            }
276            *result_elem = sum;
277        }
278        Ok(result)
279    }
280
281    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
282        // For CSR, we can compute A^T * x by transposing first
283        let transposed = self.matrix.transpose();
284        MatrixLinearOperator::new(transposed).matvec(x)
285    }
286
287    fn has_adjoint(&self) -> bool {
288        true
289    }
290}
291
292impl<F: Float + NumAssign + Sum + 'static + Debug> AsLinearOperator<F> for CsrMatrix<F> {
293    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>> {
294        Box::new(MatrixLinearOperator::new(self.clone()))
295    }
296}
297
298// Composition operators for adding and multiplying operators
299/// Sum of two linear operators: (A + B) * x = A * x + B * x
300pub struct SumOperator<F> {
301    a: Box<dyn LinearOperator<F>>,
302    b: Box<dyn LinearOperator<F>>,
303}
304
305impl<F: Float + NumAssign> SumOperator<F> {
306    /// Create a new sum operator
307    #[allow(dead_code)]
308    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
309        if a.shape() != b.shape() {
310            return Err(crate::error::SparseError::ShapeMismatch {
311                expected: a.shape(),
312                found: b.shape(),
313            });
314        }
315        Ok(Self { a, b })
316    }
317}
318
319impl<F: Float + NumAssign> LinearOperator<F> for SumOperator<F> {
320    fn shape(&self) -> (usize, usize) {
321        self.a.shape()
322    }
323
324    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
325        let a_result = self.a.matvec(x)?;
326        let b_result = self.b.matvec(x)?;
327        Ok(a_result
328            .iter()
329            .zip(&b_result)
330            .map(|(&a, &b)| a + b)
331            .collect())
332    }
333
334    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
335        if !self.a.has_adjoint() || !self.b.has_adjoint() {
336            return Err(crate::error::SparseError::OperationNotSupported(
337                "adjoint not supported for one or both operators".to_string(),
338            ));
339        }
340        let a_result = self.a.rmatvec(x)?;
341        let b_result = self.b.rmatvec(x)?;
342        Ok(a_result
343            .iter()
344            .zip(&b_result)
345            .map(|(&a, &b)| a + b)
346            .collect())
347    }
348
349    fn has_adjoint(&self) -> bool {
350        self.a.has_adjoint() && self.b.has_adjoint()
351    }
352}
353
354/// Product of two linear operators: (A * B) * x = A * (B * x)
355pub struct ProductOperator<F> {
356    a: Box<dyn LinearOperator<F>>,
357    b: Box<dyn LinearOperator<F>>,
358}
359
360impl<F: Float + NumAssign> ProductOperator<F> {
361    /// Create a new product operator
362    #[allow(dead_code)]
363    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
364        let (_a_rows, a_cols) = a.shape();
365        let (b_rows, _b_cols) = b.shape();
366        if a_cols != b_rows {
367            return Err(crate::error::SparseError::DimensionMismatch {
368                expected: a_cols,
369                found: b_rows,
370            });
371        }
372        Ok(Self { a, b })
373    }
374}
375
376impl<F: Float + NumAssign> LinearOperator<F> for ProductOperator<F> {
377    fn shape(&self) -> (usize, usize) {
378        let (a_rows, _) = self.a.shape();
379        let (_, b_cols) = self.b.shape();
380        (a_rows, b_cols)
381    }
382
383    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
384        let b_result = self.b.matvec(x)?;
385        self.a.matvec(&b_result)
386    }
387
388    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
389        if !self.a.has_adjoint() || !self.b.has_adjoint() {
390            return Err(crate::error::SparseError::OperationNotSupported(
391                "adjoint not supported for one or both operators".to_string(),
392            ));
393        }
394        // (A * B)^H = B^H * A^H
395        let a_result = self.a.rmatvec(x)?;
396        self.b.rmatvec(&a_result)
397    }
398
399    fn has_adjoint(&self) -> bool {
400        self.a.has_adjoint() && self.b.has_adjoint()
401    }
402}
403
/// Unit tests for the operators in this module.
///
/// Besides the forward products, these cover the dimension-mismatch error
/// paths, adjoint products, rectangular shapes, and the trait's default
/// methods.
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal operator that provides only the required trait methods, used
    /// to exercise the default implementations of `LinearOperator`.
    struct ForwardOnly;

    impl LinearOperator<f64> for ForwardOnly {
        fn shape(&self) -> (usize, usize) {
            (2, 2)
        }

        fn matvec(&self, x: &[f64]) -> SparseResult<Vec<f64>> {
            Ok(x.to_vec())
        }
    }

    #[test]
    fn test_identity_operator() {
        let op = IdentityOperator::<f64>::new(3);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(x, y);
    }

    #[test]
    fn test_identity_operator_adjoint_and_errors() {
        let op = IdentityOperator::<f64>::new(3);
        assert!(op.has_adjoint());
        assert_eq!(op.shape(), (3, 3));
        assert_eq!(op.rmatvec(&[1.0, 2.0, 3.0]).unwrap(), vec![1.0, 2.0, 3.0]);
        assert!(matches!(
            op.matvec(&[1.0, 2.0]),
            Err(SparseError::DimensionMismatch {
                expected: 3,
                found: 2
            })
        ));
    }

    #[test]
    fn test_scaled_identity_operator() {
        let op = ScaledIdentityOperator::new(3, 2.0);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_scaled_identity_operator_errors() {
        let op = ScaledIdentityOperator::new(3, 2.0);
        assert!(matches!(
            op.matvec(&[1.0]),
            Err(SparseError::DimensionMismatch {
                expected: 3,
                found: 1
            })
        ));
    }

    #[test]
    fn test_diagonal_operator() {
        let op = DiagonalOperator::new(vec![2.0, 3.0, 4.0]);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_diagonal_operator_adjoint_and_errors() {
        let op = DiagonalOperator::new(vec![2.0, 3.0]);
        assert_eq!(op.diagonal(), &[2.0, 3.0]);
        assert_eq!(op.rmatvec(&[1.0, 2.0]).unwrap(), vec![2.0, 6.0]);
        assert!(matches!(
            op.matvec(&[1.0, 2.0, 3.0]),
            Err(SparseError::DimensionMismatch {
                expected: 2,
                found: 3
            })
        ));
    }

    #[test]
    fn test_zero_operator() {
        let op = ZeroOperator::<f64>::new(3, 3);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_zero_operator_rectangular() {
        // A non-square shape distinguishes rows from columns.
        let op = ZeroOperator::<f64>::new(2, 3);
        assert_eq!(op.shape(), (2, 3));
        assert_eq!(op.matvec(&[1.0, 2.0, 3.0]).unwrap(), vec![0.0; 2]);
        assert_eq!(op.rmatvec(&[1.0, 2.0]).unwrap(), vec![0.0; 3]);
        assert!(matches!(
            op.matvec(&[1.0, 2.0]),
            Err(SparseError::DimensionMismatch {
                expected: 3,
                found: 2
            })
        ));
        assert!(matches!(
            op.rmatvec(&[1.0, 2.0, 3.0]),
            Err(SparseError::DimensionMismatch {
                expected: 2,
                found: 3
            })
        ));
    }

    #[test]
    fn test_sum_operator() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let sum = SumOperator::new(id, scaled).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = sum.matvec(&x).unwrap();
        assert_eq!(y, vec![3.0, 6.0, 9.0]); // (I + 2I) * x = 3x
    }

    #[test]
    fn test_sum_operator_adjoint_and_shape_check() {
        let diag = Box::new(DiagonalOperator::new(vec![1.0, 2.0, 3.0]));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let sum = SumOperator::<f64>::new(diag, scaled).unwrap();
        assert!(sum.has_adjoint());
        assert_eq!(sum.rmatvec(&[1.0, 1.0, 1.0]).unwrap(), vec![3.0, 4.0, 5.0]);

        let mismatched = SumOperator::<f64>::new(
            Box::new(IdentityOperator::<f64>::new(3)),
            Box::new(IdentityOperator::<f64>::new(4)),
        );
        assert!(matches!(
            mismatched,
            Err(SparseError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_sum_operator_without_adjoint() {
        let sum = SumOperator::<f64>::new(
            Box::new(ForwardOnly),
            Box::new(IdentityOperator::<f64>::new(2)),
        )
        .unwrap();
        assert!(!sum.has_adjoint());
        assert!(matches!(
            sum.rmatvec(&[1.0, 2.0]),
            Err(SparseError::OperationNotSupported(_))
        ));
    }

    #[test]
    fn test_product_operator() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let product = ProductOperator::new(scaled, id).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = product.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 4.0, 6.0]); // (2I * I) * x = 2x
    }

    #[test]
    fn test_product_operator_adjoint_and_shape_check() {
        let diag = Box::new(DiagonalOperator::new(vec![2.0, 3.0]));
        let scaled = Box::new(ScaledIdentityOperator::new(2, 0.5));
        let product = ProductOperator::<f64>::new(diag, scaled).unwrap();
        assert!(product.has_adjoint());
        // B^H * (A^H * x) = 0.5 * [2, 6] = [1, 3]
        assert_eq!(product.rmatvec(&[1.0, 2.0]).unwrap(), vec![1.0, 3.0]);

        let rectangular = ProductOperator::<f64>::new(
            Box::new(ZeroOperator::<f64>::new(2, 3)),
            Box::new(ZeroOperator::<f64>::new(3, 4)),
        )
        .unwrap();
        assert_eq!(rectangular.shape(), (2, 4));
        assert_eq!(rectangular.matvec(&[1.0; 4]).unwrap(), vec![0.0; 2]);
        assert_eq!(rectangular.rmatvec(&[1.0; 2]).unwrap(), vec![0.0; 4]);

        let mismatched = ProductOperator::<f64>::new(
            Box::new(ZeroOperator::<f64>::new(2, 3)),
            Box::new(ZeroOperator::<f64>::new(2, 3)),
        );
        assert!(matches!(
            mismatched,
            Err(SparseError::DimensionMismatch {
                expected: 3,
                found: 2
            })
        ));
    }

    #[test]
    fn test_default_trait_methods() {
        let op = ForwardOnly;
        assert!(!op.has_adjoint());
        assert!(matches!(
            op.rmatvec(&[1.0, 2.0]),
            Err(SparseError::OperationNotSupported(_))
        ));
        assert!(matches!(
            op.rmatmat(&[vec![1.0, 2.0]]),
            Err(SparseError::OperationNotSupported(_))
        ));
        // With no columns there is nothing to apply the adjoint to.
        assert!(op.rmatmat(&[]).unwrap().is_empty());

        let columns = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(op.matmat(&columns).unwrap(), columns);
    }

    #[test]
    fn test_matmat_applies_operator_per_column() {
        let op = DiagonalOperator::new(vec![2.0, 3.0]);
        let columns = vec![vec![1.0, 1.0], vec![0.0, 2.0]];
        let result = op.matmat(&columns).unwrap();
        assert_eq!(result, vec![vec![2.0, 3.0], vec![0.0, 6.0]]);
        assert!(op.matmat(&[vec![1.0]]).is_err());
    }
}
459}