scirs2_sparse/linalg/
interface.rs

1//! Linear operator interface for sparse matrices
2
3#![allow(unused_variables)]
4#![allow(unused_assignments)]
5#![allow(unused_mut)]
6
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::numeric::{Float, NumAssign};
10use std::fmt::Debug;
11use std::iter::Sum;
12use std::marker::PhantomData;
13
/// Boxed closure type used to store a matrix-vector product.
///
/// The closure receives the input vector as a slice and returns either the
/// resulting vector or a sparse error. The `Send + Sync` bounds are part of
/// the stored trait object.
type MatVecFn<F> = Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>;

/// Boxed closure type used to store a solver.
///
/// The closure receives a right-hand side as a slice and returns either the
/// solution vector or a sparse error. It has the same signature as
/// `MatVecFn`; the separate name only documents the intent.
type SolverFn<F> = Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>;
19
20/// Trait for representing a linear operator
21///
22/// This trait provides an abstract interface for linear operators,
23/// allowing matrix-free implementations and compositions.
24pub trait LinearOperator<F: Float> {
25    /// The shape of the operator (rows, columns)
26    fn shape(&self) -> (usize, usize);
27
28    /// Apply the operator to a vector: y = A * x
29    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>>;
30
31    /// Apply the operator to a matrix: Y = A * X
32    /// where X is column-major (each column is a vector)
33    fn matmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
34        let mut result = Vec::new();
35        for col in x {
36            result.push(self.matvec(col)?);
37        }
38        Ok(result)
39    }
40
41    /// Apply the adjoint of the operator to a vector: y = A^H * x
42    /// Default implementation returns an error
43    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
44        Err(crate::error::SparseError::OperationNotSupported(
45            "adjoint not implemented for this operator".to_string(),
46        ))
47    }
48
49    /// Apply the adjoint of the operator to a matrix: Y = A^H * X
50    /// Default implementation calls rmatvec for each column
51    fn rmatmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
52        let mut result = Vec::new();
53        for col in x {
54            result.push(self.rmatvec(col)?);
55        }
56        Ok(result)
57    }
58
59    /// Check if the operator supports adjoint operations
60    fn has_adjoint(&self) -> bool {
61        false
62    }
63}
64
/// Identity operator: I * x = x
///
/// Only the dimension is stored; the element type `F` is carried by a
/// zero-sized `PhantomData` marker.
#[derive(Clone)]
pub struct IdentityOperator<F> {
    // Number of rows and columns of the (square) identity.
    size: usize,
    // Marker tying the operator to its element type without storing a value.
    phantom: PhantomData<F>,
}

impl<F> IdentityOperator<F> {
    /// Create a new identity operator of given size
    ///
    /// The resulting operator reports the shape `(size, size)`.
    pub fn new(size: usize) -> Self {
        Self {
            size,
            phantom: PhantomData,
        }
    }
}
81
82impl<F: Float> LinearOperator<F> for IdentityOperator<F> {
83    fn shape(&self) -> (usize, usize) {
84        (self.size, self.size)
85    }
86
87    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
88        if x.len() != self.size {
89            return Err(crate::error::SparseError::DimensionMismatch {
90                expected: self.size,
91                found: x.len(),
92            });
93        }
94        Ok(x.to_vec())
95    }
96
97    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
98        self.matvec(x)
99    }
100
101    fn has_adjoint(&self) -> bool {
102        true
103    }
104}
105
/// Scaled identity operator: (alpha * I) * x = alpha * x
#[derive(Clone)]
pub struct ScaledIdentityOperator<F> {
    // Number of rows and columns of the (square) operator.
    size: usize,
    // The scalar `alpha` every input entry is multiplied by.
    scale: F,
}

impl<F: Float> ScaledIdentityOperator<F> {
    /// Create a new scaled identity operator
    ///
    /// `size` is the dimension of the square operator and `scale` the scalar
    /// factor applied to every entry.
    pub fn new(size: usize, scale: F) -> Self {
        Self { size, scale }
    }
}
119
120impl<F: Float + NumAssign> LinearOperator<F> for ScaledIdentityOperator<F> {
121    fn shape(&self) -> (usize, usize) {
122        (self.size, self.size)
123    }
124
125    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
126        if x.len() != self.size {
127            return Err(crate::error::SparseError::DimensionMismatch {
128                expected: self.size,
129                found: x.len(),
130            });
131        }
132        Ok(x.iter().map(|&xi| xi * self.scale).collect())
133    }
134
135    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
136        // For real scalars, adjoint is the same
137        self.matvec(x)
138    }
139
140    fn has_adjoint(&self) -> bool {
141        true
142    }
143}
144
/// Diagonal operator: D * x where D is a diagonal matrix
///
/// Only the diagonal entries are stored; their count determines the
/// dimension of the operator.
#[derive(Clone)]
pub struct DiagonalOperator<F> {
    // Diagonal entries, in order.
    diagonal: Vec<F>,
}

impl<F: Float> DiagonalOperator<F> {
    /// Create a new diagonal operator from diagonal values
    ///
    /// Takes ownership of `diagonal`.
    pub fn new(diagonal: Vec<F>) -> Self {
        Self { diagonal }
    }

    /// Get the diagonal values
    ///
    /// Returns a borrowed view of the stored entries.
    pub fn diagonal(&self) -> &[F] {
        &self.diagonal
    }
}
162
163impl<F: Float + NumAssign> LinearOperator<F> for DiagonalOperator<F> {
164    fn shape(&self) -> (usize, usize) {
165        let n = self.diagonal.len();
166        (n, n)
167    }
168
169    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
170        if x.len() != self.diagonal.len() {
171            return Err(crate::error::SparseError::DimensionMismatch {
172                expected: self.diagonal.len(),
173                found: x.len(),
174            });
175        }
176        Ok(x.iter()
177            .zip(&self.diagonal)
178            .map(|(&xi, &di)| xi * di)
179            .collect())
180    }
181
182    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
183        // For real diagonal matrices, adjoint is the same
184        self.matvec(x)
185    }
186
187    fn has_adjoint(&self) -> bool {
188        true
189    }
190}
191
/// Zero operator: 0 * x = 0
///
/// May be rectangular; only the shape is stored.
#[derive(Clone)]
pub struct ZeroOperator<F> {
    // Dimensions as (rows, columns).
    shape: (usize, usize),
    // Marker tying the operator to its element type without storing a value.
    _phantom: PhantomData<F>,
}

impl<F> ZeroOperator<F> {
    /// Create a new zero operator with given shape
    ///
    /// `rows` and `cols` become the reported shape `(rows, cols)`.
    #[allow(dead_code)]
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            shape: (rows, cols),
            _phantom: PhantomData,
        }
    }
}
209
210impl<F: Float> LinearOperator<F> for ZeroOperator<F> {
211    fn shape(&self) -> (usize, usize) {
212        self.shape
213    }
214
215    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
216        if x.len() != self.shape.1 {
217            return Err(crate::error::SparseError::DimensionMismatch {
218                expected: self.shape.1,
219                found: x.len(),
220            });
221        }
222        Ok(vec![F::zero(); self.shape.0])
223    }
224
225    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
226        if x.len() != self.shape.0 {
227            return Err(crate::error::SparseError::DimensionMismatch {
228                expected: self.shape.0,
229                found: x.len(),
230            });
231        }
232        Ok(vec![F::zero(); self.shape.1])
233    }
234
235    fn has_adjoint(&self) -> bool {
236        true
237    }
238}
239
/// Convert a sparse matrix to a linear operator
///
/// Implementors hand back a boxed `LinearOperator` trait object, so callers
/// can treat different matrix types uniformly.
pub trait AsLinearOperator<F: Float> {
    /// Convert to a linear operator
    ///
    /// Takes `&self`, so the receiver stays usable after the call.
    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>>;
}
245
/// Linear operator wrapper for sparse matrices
///
/// `M` is the wrapped matrix type and `F` its element type; `F` is only
/// tracked through a `PhantomData` marker.
pub struct MatrixLinearOperator<F, M> {
    // The wrapped matrix, owned by the operator.
    matrix: M,
    // Marker for the element type.
    phantom: PhantomData<F>,
}

impl<F, M> MatrixLinearOperator<F, M> {
    /// Create a new matrix linear operator
    ///
    /// Takes ownership of `matrix`.
    pub fn new(matrix: M) -> Self {
        Self {
            matrix,
            phantom: PhantomData,
        }
    }
}
261
262// Implementation of LinearOperator for CSR matrices
263use crate::csr::CsrMatrix;
264
265impl<F: Float + NumAssign + Sum + 'static + Debug> LinearOperator<F>
266    for MatrixLinearOperator<F, CsrMatrix<F>>
267{
268    fn shape(&self) -> (usize, usize) {
269        (self.matrix.rows(), self.matrix.cols())
270    }
271
272    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
273        if x.len() != self.matrix.cols() {
274            return Err(SparseError::DimensionMismatch {
275                expected: self.matrix.cols(),
276                found: x.len(),
277            });
278        }
279
280        // Manual implementation for generic types
281        let mut result = vec![F::zero(); self.matrix.rows()];
282        for (row, result_elem) in result.iter_mut().enumerate().take(self.matrix.rows()) {
283            let row_range = self.matrix.row_range(row);
284            let row_indices = &self.matrix.colindices()[row_range.clone()];
285            let row_data = &self.matrix.data[row_range];
286
287            let mut sum = F::zero();
288            for (col_idx, &col) in row_indices.iter().enumerate() {
289                sum += row_data[col_idx] * x[col];
290            }
291            *result_elem = sum;
292        }
293        Ok(result)
294    }
295
296    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
297        // For CSR, we can compute A^T * x by transposing first
298        let transposed = self.matrix.transpose();
299        MatrixLinearOperator::new(transposed).matvec(x)
300    }
301
302    fn has_adjoint(&self) -> bool {
303        true
304    }
305}
306
307// Implementation of LinearOperator for CsrArray
308use crate::csr_array::CsrArray;
309
310impl<F: Float + NumAssign + Sum + 'static + Debug> LinearOperator<F>
311    for MatrixLinearOperator<F, CsrArray<F>>
312{
313    fn shape(&self) -> (usize, usize) {
314        self.matrix.shape()
315    }
316
317    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
318        if x.len() != self.matrix.shape().1 {
319            return Err(SparseError::DimensionMismatch {
320                expected: self.matrix.shape().1,
321                found: x.len(),
322            });
323        }
324
325        use scirs2_core::ndarray::Array1;
326        let x_array = Array1::from_vec(x.to_vec());
327        let result = self.matrix.dot_vector(&x_array.view())?;
328        Ok(result.to_vec())
329    }
330
331    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
332        // For CSR A^T * x, we iterate through columns
333        if x.len() != self.matrix.shape().0 {
334            return Err(SparseError::DimensionMismatch {
335                expected: self.matrix.shape().0,
336                found: x.len(),
337            });
338        }
339
340        let mut result = vec![F::zero(); self.matrix.shape().1];
341
342        // Iterate through each row of the matrix
343        for (row_idx, &x_val) in x.iter().enumerate() {
344            if x_val != F::zero() {
345                // Get row data for this row
346                let row_start = self.matrix.get_indptr()[row_idx];
347                let row_end = self.matrix.get_indptr()[row_idx + 1];
348
349                for idx in row_start..row_end {
350                    let col_idx = self.matrix.get_indices()[idx];
351                    let data_val = self.matrix.get_data()[idx];
352                    result[col_idx] += data_val * x_val;
353                }
354            }
355        }
356
357        Ok(result)
358    }
359
360    fn has_adjoint(&self) -> bool {
361        true
362    }
363}
364
365impl<F: Float + NumAssign + Sum + 'static + Debug> AsLinearOperator<F> for CsrMatrix<F> {
366    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>> {
367        Box::new(MatrixLinearOperator::new(self.clone()))
368    }
369}
370
371impl<F: Float + NumAssign + Sum + 'static + Debug> AsLinearOperator<F>
372    for crate::csr_array::CsrArray<F>
373{
374    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>> {
375        Box::new(MatrixLinearOperator::new(self.clone()))
376    }
377}
378
379// Composition operators for adding and multiplying operators
380/// Sum of two linear operators: (A + B) * x = A * x + B * x
381pub struct SumOperator<F> {
382    a: Box<dyn LinearOperator<F>>,
383    b: Box<dyn LinearOperator<F>>,
384}
385
386impl<F: Float + NumAssign> SumOperator<F> {
387    /// Create a new sum operator
388    #[allow(dead_code)]
389    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
390        if a.shape() != b.shape() {
391            return Err(crate::error::SparseError::ShapeMismatch {
392                expected: a.shape(),
393                found: b.shape(),
394            });
395        }
396        Ok(Self { a, b })
397    }
398}
399
400impl<F: Float + NumAssign> LinearOperator<F> for SumOperator<F> {
401    fn shape(&self) -> (usize, usize) {
402        self.a.shape()
403    }
404
405    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
406        let a_result = self.a.matvec(x)?;
407        let b_result = self.b.matvec(x)?;
408        Ok(a_result
409            .iter()
410            .zip(&b_result)
411            .map(|(&a, &b)| a + b)
412            .collect())
413    }
414
415    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
416        if !self.a.has_adjoint() || !self.b.has_adjoint() {
417            return Err(crate::error::SparseError::OperationNotSupported(
418                "adjoint not supported for one or both operators".to_string(),
419            ));
420        }
421        let a_result = self.a.rmatvec(x)?;
422        let b_result = self.b.rmatvec(x)?;
423        Ok(a_result
424            .iter()
425            .zip(&b_result)
426            .map(|(&a, &b)| a + b)
427            .collect())
428    }
429
430    fn has_adjoint(&self) -> bool {
431        self.a.has_adjoint() && self.b.has_adjoint()
432    }
433}
434
435/// Product of two linear operators: (A * B) * x = A * (B * x)
436pub struct ProductOperator<F> {
437    a: Box<dyn LinearOperator<F>>,
438    b: Box<dyn LinearOperator<F>>,
439}
440
441impl<F: Float + NumAssign> ProductOperator<F> {
442    /// Create a new product operator
443    #[allow(dead_code)]
444    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
445        let (_a_rows, a_cols) = a.shape();
446        let (b_rows, b_cols) = b.shape();
447        if a_cols != b_rows {
448            return Err(crate::error::SparseError::DimensionMismatch {
449                expected: a_cols,
450                found: b_rows,
451            });
452        }
453        Ok(Self { a, b })
454    }
455}
456
457impl<F: Float + NumAssign> LinearOperator<F> for ProductOperator<F> {
458    fn shape(&self) -> (usize, usize) {
459        let (a_rows, _) = self.a.shape();
460        let (_, b_cols) = self.b.shape();
461        (a_rows, b_cols)
462    }
463
464    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
465        let b_result = self.b.matvec(x)?;
466        self.a.matvec(&b_result)
467    }
468
469    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
470        if !self.a.has_adjoint() || !self.b.has_adjoint() {
471            return Err(crate::error::SparseError::OperationNotSupported(
472                "adjoint not supported for one or both operators".to_string(),
473            ));
474        }
475        // (A * B)^H = B^H * A^H
476        let a_result = self.a.rmatvec(x)?;
477        self.b.rmatvec(&a_result)
478    }
479
480    fn has_adjoint(&self) -> bool {
481        self.a.has_adjoint() && self.b.has_adjoint()
482    }
483}
484
485/// Function-based linear operator for matrix-free implementations
486pub struct FunctionOperator<F> {
487    shape: (usize, usize),
488    matvec_fn: MatVecFn<F>,
489    rmatvec_fn: Option<MatVecFn<F>>,
490}
491
492impl<F: Float + 'static> FunctionOperator<F> {
493    /// Create a new function-based operator
494    #[allow(dead_code)]
495    pub fn new<MV, RMV>(shape: (usize, usize), matvec_fn: MV, rmatvec_fn: Option<RMV>) -> Self
496    where
497        MV: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
498        RMV: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
499    {
500        Self {
501            shape,
502            matvec_fn: Box::new(matvec_fn),
503            rmatvec_fn: rmatvec_fn
504                .map(|f| Box::new(f) as Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>),
505        }
506    }
507
508    /// Create a matrix-free operator from a function
509    #[allow(dead_code)]
510    pub fn from_function<FMv>(shape: (usize, usize), matvec_fn: FMv) -> Self
511    where
512        FMv: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
513    {
514        Self::new(shape, matvec_fn, None::<fn(&[F]) -> SparseResult<Vec<F>>>)
515    }
516}
517
518impl<F: Float> LinearOperator<F> for FunctionOperator<F> {
519    fn shape(&self) -> (usize, usize) {
520        self.shape
521    }
522
523    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
524        (self.matvec_fn)(x)
525    }
526
527    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
528        match &self.rmatvec_fn {
529            Some(f) => f(x),
530            None => Err(SparseError::OperationNotSupported(
531                "adjoint not implemented for this function operator".to_string(),
532            )),
533        }
534    }
535
536    fn has_adjoint(&self) -> bool {
537        self.rmatvec_fn.is_some()
538    }
539}
540
541/// Inverse operator: A^(-1)
542/// Note: This is a conceptual operator, actual implementation depends on the specific matrix
543pub struct InverseOperator<F> {
544    original: Box<dyn LinearOperator<F>>,
545    solver_fn: SolverFn<F>,
546}
547
548impl<F: Float> InverseOperator<F> {
549    /// Create a new inverse operator with a custom solver function
550    #[allow(dead_code)]
551    pub fn new<S>(original: Box<dyn LinearOperator<F>>, solver_fn: S) -> SparseResult<Self>
552    where
553        S: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
554    {
555        let (rows, cols) = original.shape();
556        if rows != cols {
557            return Err(SparseError::ValueError(
558                "Cannot invert non-square operator".to_string(),
559            ));
560        }
561
562        Ok(Self {
563            original,
564            solver_fn: Box::new(solver_fn),
565        })
566    }
567}
568
569impl<F: Float> LinearOperator<F> for InverseOperator<F> {
570    fn shape(&self) -> (usize, usize) {
571        self.original.shape()
572    }
573
574    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
575        // A^(-1) * x is equivalent to solving A * y = x for y
576        (self.solver_fn)(x)
577    }
578
579    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
580        // (A^(-1))^H = (A^H)^(-1)
581        // So we need to solve A^H * y = x for y
582        if !self.original.has_adjoint() {
583            return Err(SparseError::OperationNotSupported(
584                "adjoint not supported for original operator".to_string(),
585            ));
586        }
587
588        // This is a conceptual implementation - in practice, you'd need
589        // a solver for the adjoint system
590        Err(SparseError::OperationNotSupported(
591            "adjoint of inverse operator not yet implemented".to_string(),
592        ))
593    }
594
595    fn has_adjoint(&self) -> bool {
596        false // Simplified for now
597    }
598}
599
/// Transpose operator: A^T
///
/// Wraps another operator and swaps its forward and adjoint products.
pub struct TransposeOperator<F> {
    // The operator being transposed.
    original: Box<dyn LinearOperator<F>>,
}

impl<F: Float + NumAssign> TransposeOperator<F> {
    /// Create a new transpose operator
    ///
    /// Takes ownership of `original`. No check is made here that `original`
    /// supports adjoint products.
    pub fn new(original: Box<dyn LinearOperator<F>>) -> Self {
        Self { original }
    }
}
611
612impl<F: Float + NumAssign> LinearOperator<F> for TransposeOperator<F> {
613    fn shape(&self) -> (usize, usize) {
614        let (rows, cols) = self.original.shape();
615        (cols, rows) // Transpose dimensions
616    }
617
618    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
619        // A^T * x = (A^H * x) for real matrices
620        self.original.rmatvec(x)
621    }
622
623    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
624        // (A^T)^H = A for real matrices
625        self.original.matvec(x)
626    }
627
628    fn has_adjoint(&self) -> bool {
629        true // Transpose always has adjoint
630    }
631}
632
633/// Adjoint operator: A^H (Hermitian transpose)
634pub struct AdjointOperator<F> {
635    original: Box<dyn LinearOperator<F>>,
636}
637
638impl<F: Float + NumAssign> AdjointOperator<F> {
639    /// Create a new adjoint operator
640    pub fn new(original: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
641        if !original.has_adjoint() {
642            return Err(SparseError::OperationNotSupported(
643                "Original operator does not support adjoint operations".to_string(),
644            ));
645        }
646        Ok(Self { original })
647    }
648}
649
650impl<F: Float + NumAssign> LinearOperator<F> for AdjointOperator<F> {
651    fn shape(&self) -> (usize, usize) {
652        let (rows, cols) = self.original.shape();
653        (cols, rows) // Adjoint transposes dimensions
654    }
655
656    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
657        self.original.rmatvec(x)
658    }
659
660    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
661        self.original.matvec(x)
662    }
663
664    fn has_adjoint(&self) -> bool {
665        true
666    }
667}
668
669/// Difference operator: A - B
670pub struct DifferenceOperator<F> {
671    a: Box<dyn LinearOperator<F>>,
672    b: Box<dyn LinearOperator<F>>,
673}
674
675impl<F: Float + NumAssign> DifferenceOperator<F> {
676    /// Create a new difference operator
677    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
678        if a.shape() != b.shape() {
679            return Err(SparseError::ShapeMismatch {
680                expected: a.shape(),
681                found: b.shape(),
682            });
683        }
684        Ok(Self { a, b })
685    }
686}
687
688impl<F: Float + NumAssign> LinearOperator<F> for DifferenceOperator<F> {
689    fn shape(&self) -> (usize, usize) {
690        self.a.shape()
691    }
692
693    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
694        let a_result = self.a.matvec(x)?;
695        let b_result = self.b.matvec(x)?;
696        Ok(a_result
697            .iter()
698            .zip(&b_result)
699            .map(|(&a, &b)| a - b)
700            .collect())
701    }
702
703    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
704        if !self.a.has_adjoint() || !self.b.has_adjoint() {
705            return Err(SparseError::OperationNotSupported(
706                "adjoint not supported for one or both operators".to_string(),
707            ));
708        }
709        let a_result = self.a.rmatvec(x)?;
710        let b_result = self.b.rmatvec(x)?;
711        Ok(a_result
712            .iter()
713            .zip(&b_result)
714            .map(|(&a, &b)| a - b)
715            .collect())
716    }
717
718    fn has_adjoint(&self) -> bool {
719        self.a.has_adjoint() && self.b.has_adjoint()
720    }
721}
722
/// Scaled operator: alpha * A
///
/// Wraps another operator and multiplies each of its products by a scalar.
pub struct ScaledOperator<F> {
    // Scalar factor applied to every output entry.
    alpha: F,
    // The operator being scaled.
    operator: Box<dyn LinearOperator<F>>,
}

impl<F: Float + NumAssign> ScaledOperator<F> {
    /// Create a new scaled operator
    ///
    /// Takes ownership of `operator`; `alpha` is the scalar factor.
    pub fn new(alpha: F, operator: Box<dyn LinearOperator<F>>) -> Self {
        Self { alpha, operator }
    }
}
735
736impl<F: Float + NumAssign> LinearOperator<F> for ScaledOperator<F> {
737    fn shape(&self) -> (usize, usize) {
738        self.operator.shape()
739    }
740
741    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
742        let result = self.operator.matvec(x)?;
743        Ok(result.iter().map(|&val| self.alpha * val).collect())
744    }
745
746    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
747        if !self.operator.has_adjoint() {
748            return Err(SparseError::OperationNotSupported(
749                "adjoint not supported for underlying operator".to_string(),
750            ));
751        }
752        let result = self.operator.rmatvec(x)?;
753        Ok(result.iter().map(|&val| self.alpha * val).collect())
754    }
755
756    fn has_adjoint(&self) -> bool {
757        self.operator.has_adjoint()
758    }
759}
760
761/// Chain/composition of multiple operators: A_n * A_(n-1) * ... * A_1
762pub struct ChainOperator<F> {
763    operators: Vec<Box<dyn LinearOperator<F>>>,
764    totalshape: (usize, usize),
765}
766
767impl<F: Float + NumAssign> ChainOperator<F> {
768    /// Create a new chain operator from a list of operators
769    /// Operators are applied from right to left (like function composition)
770    #[allow(dead_code)]
771    pub fn new(operators: Vec<Box<dyn LinearOperator<F>>>) -> SparseResult<Self> {
772        if operators.is_empty() {
773            return Err(SparseError::ValueError(
774                "Cannot create chain with no operators".to_string(),
775            ));
776        }
777
778        // Check dimension compatibility
779        #[allow(clippy::needless_range_loop)]
780        for i in 0..operators.len() - 1 {
781            let (_, a_cols) = operators[i].shape();
782            let (b_rows, _) = operators[i + 1].shape();
783            if a_cols != b_rows {
784                return Err(SparseError::DimensionMismatch {
785                    expected: a_cols,
786                    found: b_rows,
787                });
788            }
789        }
790
791        let (first_rows, _) = operators[0].shape();
792        let (_, last_cols) = operators.last().unwrap().shape();
793        let totalshape = (first_rows, last_cols);
794
795        Ok(Self {
796            operators,
797            totalshape,
798        })
799    }
800}
801
802impl<F: Float + NumAssign> LinearOperator<F> for ChainOperator<F> {
803    fn shape(&self) -> (usize, usize) {
804        self.totalshape
805    }
806
807    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
808        let mut result = x.to_vec();
809        // Apply operators from right to left
810        for op in self.operators.iter().rev() {
811            result = op.matvec(&result)?;
812        }
813        Ok(result)
814    }
815
816    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
817        // Check if all operators support adjoint
818        for op in &self.operators {
819            if !op.has_adjoint() {
820                return Err(SparseError::OperationNotSupported(
821                    "adjoint not supported for all operators in chain".to_string(),
822                ));
823            }
824        }
825
826        let mut result = x.to_vec();
827        // Apply adjoints from left to right (reverse order)
828        for op in &self.operators {
829            result = op.rmatvec(&result)?;
830        }
831        Ok(result)
832    }
833
834    fn has_adjoint(&self) -> bool {
835        self.operators.iter().all(|op| op.has_adjoint())
836    }
837}
838
839/// Power operator: A^n (for positive integer n)
840pub struct PowerOperator<F> {
841    operator: Box<dyn LinearOperator<F>>,
842    power: usize,
843}
844
845impl<F: Float + NumAssign> PowerOperator<F> {
846    /// Create a new power operator
847    pub fn new(operator: Box<dyn LinearOperator<F>>, power: usize) -> SparseResult<Self> {
848        let (rows, cols) = operator.shape();
849        if rows != cols {
850            return Err(SparseError::ValueError(
851                "Can only compute powers of square operators".to_string(),
852            ));
853        }
854        if power == 0 {
855            return Err(SparseError::ValueError(
856                "Power must be positive".to_string(),
857            ));
858        }
859        Ok(Self { operator, power })
860    }
861}
862
863impl<F: Float + NumAssign> LinearOperator<F> for PowerOperator<F> {
864    fn shape(&self) -> (usize, usize) {
865        self.operator.shape()
866    }
867
868    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
869        let mut result = x.to_vec();
870        for _ in 0..self.power {
871            result = self.operator.matvec(&result)?;
872        }
873        Ok(result)
874    }
875
876    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
877        if !self.operator.has_adjoint() {
878            return Err(SparseError::OperationNotSupported(
879                "adjoint not supported for underlying operator".to_string(),
880            ));
881        }
882        let mut result = x.to_vec();
883        for _ in 0..self.power {
884            result = self.operator.rmatvec(&result)?;
885        }
886        Ok(result)
887    }
888
889    fn has_adjoint(&self) -> bool {
890        self.operator.has_adjoint()
891    }
892}
893
/// Enhanced LinearOperator trait with composition methods
///
/// Every method takes `&self` and returns a new boxed operator, so the
/// receiver stays usable afterwards. The methods that return a
/// `SparseResult` can fail.
#[allow(dead_code)]
pub trait LinearOperatorExt<F: Float + NumAssign>: LinearOperator<F> {
    /// Add this operator with another
    ///
    /// Represents `self + other`.
    fn add(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;

    /// Subtract another operator from this one
    ///
    /// Represents `self - other`.
    fn sub(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;

    /// Multiply this operator with another (composition)
    ///
    /// Represents `self * other`.
    fn mul(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;

    /// Scale this operator by a scalar
    ///
    /// Represents `alpha * self`.
    fn scale(&self, alpha: F) -> Box<dyn LinearOperator<F>>;

    /// Transpose this operator
    fn transpose(&self) -> Box<dyn LinearOperator<F>>;

    /// Adjoint of this operator
    fn adjoint(&self) -> SparseResult<Box<dyn LinearOperator<F>>>;

    /// Power of this operator
    ///
    /// `n` is the exponent.
    fn pow(&self, n: usize) -> SparseResult<Box<dyn LinearOperator<F>>>;
}
918
919// Specific implementations for each cloneable operator type
920macro_rules! impl_linear_operator_ext {
921    ($typ:ty) => {
922        impl<F: Float + NumAssign + Copy + 'static> LinearOperatorExt<F> for $typ {
923            fn add(
924                &self,
925                other: Box<dyn LinearOperator<F>>,
926            ) -> SparseResult<Box<dyn LinearOperator<F>>> {
927                let self_box = Box::new(self.clone());
928                Ok(Box::new(SumOperator::new(self_box, other)?))
929            }
930
931            fn sub(
932                &self,
933                other: Box<dyn LinearOperator<F>>,
934            ) -> SparseResult<Box<dyn LinearOperator<F>>> {
935                let self_box = Box::new(self.clone());
936                Ok(Box::new(DifferenceOperator::new(self_box, other)?))
937            }
938
939            fn mul(
940                &self,
941                other: Box<dyn LinearOperator<F>>,
942            ) -> SparseResult<Box<dyn LinearOperator<F>>> {
943                let self_box = Box::new(self.clone());
944                Ok(Box::new(ProductOperator::new(self_box, other)?))
945            }
946
947            fn scale(&self, alpha: F) -> Box<dyn LinearOperator<F>> {
948                let self_box = Box::new(self.clone());
949                Box::new(ScaledOperator::new(alpha, self_box))
950            }
951
952            fn transpose(&self) -> Box<dyn LinearOperator<F>> {
953                let self_box = Box::new(self.clone());
954                Box::new(TransposeOperator::new(self_box))
955            }
956
957            fn adjoint(&self) -> SparseResult<Box<dyn LinearOperator<F>>> {
958                let self_box = Box::new(self.clone());
959                Ok(Box::new(AdjointOperator::new(self_box)?))
960            }
961
962            fn pow(&self, n: usize) -> SparseResult<Box<dyn LinearOperator<F>>> {
963                let self_box = Box::new(self.clone());
964                Ok(Box::new(PowerOperator::new(self_box, n)?))
965            }
966        }
967    };
968}
969
970// Apply the macro to all cloneable operator types
971impl_linear_operator_ext!(IdentityOperator<F>);
972impl_linear_operator_ext!(ScaledIdentityOperator<F>);
973impl_linear_operator_ext!(DiagonalOperator<F>);
974impl_linear_operator_ext!(ZeroOperator<F>);
975
976/// Utility functions for operator composition
977/// Add two operators: left + right
978#[allow(dead_code)]
979pub fn add_operators<F: Float + NumAssign + 'static>(
980    left: Box<dyn LinearOperator<F>>,
981    right: Box<dyn LinearOperator<F>>,
982) -> SparseResult<Box<dyn LinearOperator<F>>> {
983    Ok(Box::new(SumOperator::new(left, right)?))
984}
985
986/// Subtract two operators: left - right
987#[allow(dead_code)]
988pub fn subtract_operators<F: Float + NumAssign + 'static>(
989    left: Box<dyn LinearOperator<F>>,
990    right: Box<dyn LinearOperator<F>>,
991) -> SparseResult<Box<dyn LinearOperator<F>>> {
992    Ok(Box::new(DifferenceOperator::new(left, right)?))
993}
994
995/// Multiply two operators: left * right  
996#[allow(dead_code)]
997pub fn multiply_operators<F: Float + NumAssign + 'static>(
998    left: Box<dyn LinearOperator<F>>,
999    right: Box<dyn LinearOperator<F>>,
1000) -> SparseResult<Box<dyn LinearOperator<F>>> {
1001    Ok(Box::new(ProductOperator::new(left, right)?))
1002}
1003
1004/// Scale an operator: alpha * operator
1005#[allow(dead_code)]
1006pub fn scale_operator<F: Float + NumAssign + 'static>(
1007    alpha: F,
1008    operator: Box<dyn LinearOperator<F>>,
1009) -> Box<dyn LinearOperator<F>> {
1010    Box::new(ScaledOperator::new(alpha, operator))
1011}
1012
1013/// Transpose an operator: A^T
1014#[allow(dead_code)]
1015pub fn transpose_operator<F: Float + NumAssign + 'static>(
1016    operator: Box<dyn LinearOperator<F>>,
1017) -> Box<dyn LinearOperator<F>> {
1018    Box::new(TransposeOperator::new(operator))
1019}
1020
1021/// Adjoint of an operator: A^H
1022#[allow(dead_code)]
1023pub fn adjoint_operator<F: Float + NumAssign + 'static>(
1024    operator: Box<dyn LinearOperator<F>>,
1025) -> SparseResult<Box<dyn LinearOperator<F>>> {
1026    Ok(Box::new(AdjointOperator::new(operator)?))
1027}
1028
1029/// Compose multiple operators: A_n * A_(n-1) * ... * A_1
1030#[allow(dead_code)]
1031pub fn compose_operators<F: Float + NumAssign + 'static>(
1032    operators: Vec<Box<dyn LinearOperator<F>>>,
1033) -> SparseResult<Box<dyn LinearOperator<F>>> {
1034    Ok(Box::new(ChainOperator::new(operators)?))
1035}
1036
1037/// Power of an operator: A^n
1038#[allow(dead_code)]
1039pub fn power_operator<F: Float + NumAssign + 'static>(
1040    operator: Box<dyn LinearOperator<F>>,
1041    n: usize,
1042) -> SparseResult<Box<dyn LinearOperator<F>>> {
1043    Ok(Box::new(PowerOperator::new(operator, n)?))
1044}
1045
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared input vector fed to `matvec` by most tests in this module.
    fn probe() -> Vec<f64> {
        vec![1.0, 2.0, 3.0]
    }

    /// The identity operator is expected to return its input unchanged.
    #[test]
    fn test_identity_operator() {
        let identity = IdentityOperator::<f64>::new(3);
        let input = probe();
        let output = identity.matvec(&input).unwrap();
        assert_eq!(input, output);
    }

    /// A scaled identity with factor 2 is expected to double every entry.
    #[test]
    fn test_scaled_identity_operator() {
        let doubling = ScaledIdentityOperator::new(3, 2.0);
        let output = doubling.matvec(&probe()).unwrap();
        assert_eq!(output, vec![2.0, 4.0, 6.0]);
    }

    /// A diagonal operator is expected to multiply entry-wise by its diagonal.
    #[test]
    fn test_diagonal_operator() {
        let diagonal = DiagonalOperator::new(vec![2.0, 3.0, 4.0]);
        let output = diagonal.matvec(&probe()).unwrap();
        assert_eq!(output, vec![2.0, 6.0, 12.0]);
    }

    /// The zero operator is expected to map every input to the zero vector.
    #[test]
    fn test_zero_operator() {
        let zero = ZeroOperator::<f64>::new(3, 3);
        let output = zero.matvec(&probe()).unwrap();
        assert_eq!(output, vec![0.0, 0.0, 0.0]);
    }

    /// `I + 2I` is expected to act as `3I`.
    #[test]
    fn test_sum_operator() {
        let identity = Box::new(IdentityOperator::<f64>::new(3));
        let doubling = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let combined = SumOperator::new(identity, doubling).unwrap();
        let output = combined.matvec(&probe()).unwrap();
        assert_eq!(output, vec![3.0, 6.0, 9.0]);
    }

    /// `2I * I` is expected to act as `2I`.
    #[test]
    fn test_product_operator() {
        let identity = Box::new(IdentityOperator::<f64>::new(3));
        let doubling = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let combined = ProductOperator::new(doubling, identity).unwrap();
        let output = combined.matvec(&probe()).unwrap();
        assert_eq!(output, vec![2.0, 4.0, 6.0]);
    }

    /// `3I - 2I` is expected to act as the identity.
    #[test]
    fn test_difference_operator() {
        let tripling = Box::new(ScaledIdentityOperator::new(3, 3.0));
        let doubling = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let combined = DifferenceOperator::new(tripling, doubling).unwrap();
        let output = combined.matvec(&probe()).unwrap();
        assert_eq!(output, vec![1.0, 2.0, 3.0]);
    }

    /// Scaling the identity by 5 is expected to multiply every entry by 5.
    #[test]
    fn test_scaled_operator() {
        let identity = Box::new(IdentityOperator::<f64>::new(3));
        let quintupled = ScaledOperator::new(5.0, identity);
        let output = quintupled.matvec(&probe()).unwrap();
        assert_eq!(output, vec![5.0, 10.0, 15.0]);
    }

    /// Transposing a diagonal operator is expected to leave its action unchanged.
    #[test]
    fn test_transpose_operator() {
        let diagonal = Box::new(DiagonalOperator::new(vec![2.0, 3.0, 4.0]));
        let transposed = TransposeOperator::new(diagonal);
        let output = transposed.matvec(&probe()).unwrap();
        assert_eq!(output, vec![2.0, 6.0, 12.0]);
    }

    /// The adjoint of a real diagonal operator is expected to act like the original.
    #[test]
    fn test_adjoint_operator() {
        let diagonal = Box::new(DiagonalOperator::new(vec![2.0, 3.0, 4.0]));
        let adjointed = AdjointOperator::new(diagonal).unwrap();
        let output = adjointed.matvec(&probe()).unwrap();
        assert_eq!(output, vec![2.0, 6.0, 12.0]);
    }

    /// Chaining `2I` and `3I` is expected to act as `6I`; scalar multiples of
    /// the identity commute, so the result does not depend on application order.
    #[test]
    fn test_chain_operator() {
        let doubling = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let tripling = Box::new(ScaledIdentityOperator::new(3, 3.0));
        let chained = ChainOperator::new(vec![doubling, tripling]).unwrap();
        let output = chained.matvec(&probe()).unwrap();
        assert_eq!(output, vec![6.0, 12.0, 18.0]);
    }

    /// `(2I)^3` is expected to act as `8I`.
    #[test]
    fn test_power_operator() {
        let doubling = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let cubed = PowerOperator::new(doubling, 3).unwrap();
        let output = cubed.matvec(&probe()).unwrap();
        assert_eq!(output, vec![8.0, 16.0, 24.0]);
    }

    /// Exercises the free-standing composition helpers on `I` and `2I`.
    #[test]
    fn test_composition_utility_functions() {
        let identity = Box::new(IdentityOperator::<f64>::new(3));
        let doubling = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let input = probe();

        // add_operators: I + 2I acts as 3I
        let summed = add_operators(identity.clone(), doubling.clone()).unwrap();
        let via_add = summed.matvec(&input).unwrap();
        assert_eq!(via_add, vec![3.0, 6.0, 9.0]);

        // subtract_operators: 2I - I acts as I
        let subtracted = subtract_operators(doubling.clone(), identity.clone()).unwrap();
        let via_sub = subtracted.matvec(&input).unwrap();
        assert_eq!(via_sub, vec![1.0, 2.0, 3.0]);

        // multiply_operators: 2I * I acts as 2I
        let multiplied = multiply_operators(doubling.clone(), identity.clone()).unwrap();
        let via_mul = multiplied.matvec(&input).unwrap();
        assert_eq!(via_mul, vec![2.0, 4.0, 6.0]);

        // scale_operator: 3 * I acts as 3I
        let tripled = scale_operator(3.0, identity.clone());
        let via_scale = tripled.matvec(&input).unwrap();
        assert_eq!(via_scale, vec![3.0, 6.0, 9.0]);

        // transpose_operator: (2I)^T acts as 2I
        let transposed = transpose_operator(doubling.clone());
        let via_transpose = transposed.matvec(&input).unwrap();
        assert_eq!(via_transpose, vec![2.0, 4.0, 6.0]);

        // compose_operators: the chain of 2I and I acts as 2I
        let factors: Vec<Box<dyn LinearOperator<f64>>> = vec![doubling.clone(), identity.clone()];
        let chained = compose_operators(factors).unwrap();
        let via_compose = chained.matvec(&input).unwrap();
        assert_eq!(via_compose, vec![2.0, 4.0, 6.0]);

        // power_operator: (2I)^2 acts as 4I
        let squared = power_operator(doubling.clone(), 2).unwrap();
        let via_power = squared.matvec(&input).unwrap();
        assert_eq!(via_power, vec![4.0, 8.0, 12.0]);
    }

    /// Constructors of binary operators are expected to reject incompatible shapes.
    #[test]
    fn test_dimension_mismatch_errors() {
        let three_by_three = Box::new(IdentityOperator::<f64>::new(3));
        let four_by_four = Box::new(IdentityOperator::<f64>::new(4));

        // Sum of a 3x3 and a 4x4 operator must be rejected.
        let sum_attempt = SumOperator::new(three_by_three.clone(), four_by_four.clone());
        assert!(sum_attempt.is_err());

        // Difference of a 3x3 and a 4x4 operator must be rejected.
        let diff_attempt = DifferenceOperator::new(three_by_three.clone(), four_by_four.clone());
        assert!(diff_attempt.is_err());

        // Product of a 3x4 with a 5x3 operator must be rejected (4 != 5).
        let left_factor = Box::new(ZeroOperator::<f64>::new(3, 4));
        let right_factor = Box::new(ZeroOperator::<f64>::new(5, 3));
        let product_attempt = ProductOperator::new(left_factor, right_factor);
        assert!(product_attempt.is_err());
    }

    /// Wrapping an operator built from a plain forward function in an
    /// `AdjointOperator` is expected to fail.
    #[test]
    fn test_adjoint_not_supported_error() {
        // Only the forward action is supplied here.
        let forward_only = Box::new(FunctionOperator::from_function((3, 3), |x: &[f64]| {
            Ok(x.to_vec())
        }));

        let adjoint_attempt = AdjointOperator::new(forward_only);
        assert!(adjoint_attempt.is_err());
    }

    /// `PowerOperator::new` is expected to reject non-square operators and a zero exponent.
    #[test]
    fn test_power_operator_errors() {
        // A 3x4 operator cannot be raised to a power.
        let rectangular = Box::new(ZeroOperator::<f64>::new(3, 4));
        let non_square_attempt = PowerOperator::new(rectangular, 2);
        assert!(non_square_attempt.is_err());

        // An exponent of 0 is rejected even for a square operator.
        let square = Box::new(IdentityOperator::<f64>::new(3));
        let zero_exponent_attempt = PowerOperator::new(square, 0);
        assert!(zero_exponent_attempt.is_err());
    }
}