scirs2_sparse/linalg/
interface.rs

1//! Linear operator interface for sparse matrices
2
3#![allow(unused_variables)]
4#![allow(unused_assignments)]
5#![allow(unused_mut)]
6
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::numeric::{Float, NumAssign};
10use scirs2_core::SparseElement;
11use std::fmt::Debug;
12use std::iter::Sum;
13use std::marker::PhantomData;
14
15/// Type alias for matrix-vector function
16type MatVecFn<F> = Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>;
17
18/// Type alias for solver function
19type SolverFn<F> = Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>;
20
21/// Trait for representing a linear operator
22///
23/// This trait provides an abstract interface for linear operators,
24/// allowing matrix-free implementations and compositions.
25pub trait LinearOperator<F: Float> {
26    /// The shape of the operator (rows, columns)
27    fn shape(&self) -> (usize, usize);
28
29    /// Apply the operator to a vector: y = A * x
30    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>>;
31
32    /// Apply the operator to a matrix: Y = A * X
33    /// where X is column-major (each column is a vector)
34    fn matmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
35        let mut result = Vec::new();
36        for col in x {
37            result.push(self.matvec(col)?);
38        }
39        Ok(result)
40    }
41
42    /// Apply the adjoint of the operator to a vector: y = A^H * x
43    /// Default implementation returns an error
44    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
45        Err(crate::error::SparseError::OperationNotSupported(
46            "adjoint not implemented for this operator".to_string(),
47        ))
48    }
49
50    /// Apply the adjoint of the operator to a matrix: Y = A^H * X
51    /// Default implementation calls rmatvec for each column
52    fn rmatmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
53        let mut result = Vec::new();
54        for col in x {
55            result.push(self.rmatvec(col)?);
56        }
57        Ok(result)
58    }
59
60    /// Check if the operator supports adjoint operations
61    fn has_adjoint(&self) -> bool {
62        false
63    }
64}
65
66/// Identity operator: I * x = x
67#[derive(Clone)]
68pub struct IdentityOperator<F> {
69    size: usize,
70    phantom: PhantomData<F>,
71}
72
73impl<F> IdentityOperator<F> {
74    /// Create a new identity operator of given size
75    pub fn new(size: usize) -> Self {
76        Self {
77            size,
78            phantom: PhantomData,
79        }
80    }
81}
82
83impl<F: Float> LinearOperator<F> for IdentityOperator<F> {
84    fn shape(&self) -> (usize, usize) {
85        (self.size, self.size)
86    }
87
88    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
89        if x.len() != self.size {
90            return Err(crate::error::SparseError::DimensionMismatch {
91                expected: self.size,
92                found: x.len(),
93            });
94        }
95        Ok(x.to_vec())
96    }
97
98    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
99        self.matvec(x)
100    }
101
102    fn has_adjoint(&self) -> bool {
103        true
104    }
105}
106
107/// Scaled identity operator: (alpha * I) * x = alpha * x
108#[derive(Clone)]
109pub struct ScaledIdentityOperator<F> {
110    size: usize,
111    scale: F,
112}
113
114impl<F: Float> ScaledIdentityOperator<F> {
115    /// Create a new scaled identity operator
116    pub fn new(size: usize, scale: F) -> Self {
117        Self { size, scale }
118    }
119}
120
121impl<F: Float + NumAssign> LinearOperator<F> for ScaledIdentityOperator<F> {
122    fn shape(&self) -> (usize, usize) {
123        (self.size, self.size)
124    }
125
126    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
127        if x.len() != self.size {
128            return Err(crate::error::SparseError::DimensionMismatch {
129                expected: self.size,
130                found: x.len(),
131            });
132        }
133        Ok(x.iter().map(|&xi| xi * self.scale).collect())
134    }
135
136    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
137        // For real scalars, adjoint is the same
138        self.matvec(x)
139    }
140
141    fn has_adjoint(&self) -> bool {
142        true
143    }
144}
145
146/// Diagonal operator: D * x where D is a diagonal matrix
147#[derive(Clone)]
148pub struct DiagonalOperator<F> {
149    diagonal: Vec<F>,
150}
151
152impl<F: Float> DiagonalOperator<F> {
153    /// Create a new diagonal operator from diagonal values
154    pub fn new(diagonal: Vec<F>) -> Self {
155        Self { diagonal }
156    }
157
158    /// Get the diagonal values
159    pub fn diagonal(&self) -> &[F] {
160        &self.diagonal
161    }
162}
163
164impl<F: Float + NumAssign> LinearOperator<F> for DiagonalOperator<F> {
165    fn shape(&self) -> (usize, usize) {
166        let n = self.diagonal.len();
167        (n, n)
168    }
169
170    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
171        if x.len() != self.diagonal.len() {
172            return Err(crate::error::SparseError::DimensionMismatch {
173                expected: self.diagonal.len(),
174                found: x.len(),
175            });
176        }
177        Ok(x.iter()
178            .zip(&self.diagonal)
179            .map(|(&xi, &di)| xi * di)
180            .collect())
181    }
182
183    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
184        // For real diagonal matrices, adjoint is the same
185        self.matvec(x)
186    }
187
188    fn has_adjoint(&self) -> bool {
189        true
190    }
191}
192
193/// Zero operator: 0 * x = 0
194#[derive(Clone)]
195pub struct ZeroOperator<F> {
196    shape: (usize, usize),
197    _phantom: PhantomData<F>,
198}
199
200impl<F> ZeroOperator<F> {
201    /// Create a new zero operator with given shape
202    #[allow(dead_code)]
203    pub fn new(rows: usize, cols: usize) -> Self {
204        Self {
205            shape: (rows, cols),
206            _phantom: PhantomData,
207        }
208    }
209}
210
211impl<F: Float + SparseElement> LinearOperator<F> for ZeroOperator<F> {
212    fn shape(&self) -> (usize, usize) {
213        self.shape
214    }
215
216    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
217        if x.len() != self.shape.1 {
218            return Err(crate::error::SparseError::DimensionMismatch {
219                expected: self.shape.1,
220                found: x.len(),
221            });
222        }
223        Ok(vec![F::sparse_zero(); self.shape.0])
224    }
225
226    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
227        if x.len() != self.shape.0 {
228            return Err(crate::error::SparseError::DimensionMismatch {
229                expected: self.shape.0,
230                found: x.len(),
231            });
232        }
233        Ok(vec![F::sparse_zero(); self.shape.1])
234    }
235
236    fn has_adjoint(&self) -> bool {
237        true
238    }
239}
240
241/// Convert a sparse matrix to a linear operator
242pub trait AsLinearOperator<F: Float> {
243    /// Convert to a linear operator
244    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>>;
245}
246
/// Adapter exposing a sparse matrix of type `M` through the `LinearOperator` trait.
///
/// The element type `F` does not appear in any field, so it is tracked with `PhantomData`.
pub struct MatrixLinearOperator<F, M> {
    matrix: M,
    phantom: PhantomData<F>,
}

impl<F, M> MatrixLinearOperator<F, M> {
    /// Takes ownership of `matrix` and wraps it.
    pub fn new(matrix: M) -> Self {
        let phantom = PhantomData;
        MatrixLinearOperator { matrix, phantom }
    }
}
262
263// Implementation of LinearOperator for CSR matrices
264use crate::csr::CsrMatrix;
265
266impl<F: Float + SparseElement + NumAssign + Sum + 'static + Debug> LinearOperator<F>
267    for MatrixLinearOperator<F, CsrMatrix<F>>
268{
269    fn shape(&self) -> (usize, usize) {
270        (self.matrix.rows(), self.matrix.cols())
271    }
272
273    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
274        if x.len() != self.matrix.cols() {
275            return Err(SparseError::DimensionMismatch {
276                expected: self.matrix.cols(),
277                found: x.len(),
278            });
279        }
280
281        // Manual implementation for generic types
282        let mut result = vec![F::sparse_zero(); self.matrix.rows()];
283        for (row, result_elem) in result.iter_mut().enumerate().take(self.matrix.rows()) {
284            let row_range = self.matrix.row_range(row);
285            let row_indices = &self.matrix.colindices()[row_range.clone()];
286            let row_data = &self.matrix.data[row_range];
287
288            let mut sum = F::sparse_zero();
289            for (col_idx, &col) in row_indices.iter().enumerate() {
290                sum += row_data[col_idx] * x[col];
291            }
292            *result_elem = sum;
293        }
294        Ok(result)
295    }
296
297    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
298        // For CSR, we can compute A^T * x by transposing first
299        let transposed = self.matrix.transpose();
300        MatrixLinearOperator::new(transposed).matvec(x)
301    }
302
303    fn has_adjoint(&self) -> bool {
304        true
305    }
306}
307
308// Implementation of LinearOperator for CsrArray
309use crate::csr_array::CsrArray;
310
311impl<F: Float + SparseElement + NumAssign + Sum + 'static + Debug> LinearOperator<F>
312    for MatrixLinearOperator<F, CsrArray<F>>
313{
314    fn shape(&self) -> (usize, usize) {
315        self.matrix.shape()
316    }
317
318    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
319        if x.len() != self.matrix.shape().1 {
320            return Err(SparseError::DimensionMismatch {
321                expected: self.matrix.shape().1,
322                found: x.len(),
323            });
324        }
325
326        use scirs2_core::ndarray::Array1;
327        let x_array = Array1::from_vec(x.to_vec());
328        let result = self.matrix.dot_vector(&x_array.view())?;
329        Ok(result.to_vec())
330    }
331
332    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
333        // For CSR A^T * x, we iterate through columns
334        if x.len() != self.matrix.shape().0 {
335            return Err(SparseError::DimensionMismatch {
336                expected: self.matrix.shape().0,
337                found: x.len(),
338            });
339        }
340
341        let mut result = vec![F::sparse_zero(); self.matrix.shape().1];
342
343        // Iterate through each row of the matrix
344        for (row_idx, &x_val) in x.iter().enumerate() {
345            if x_val != F::sparse_zero() {
346                // Get row data for this row
347                let row_start = self.matrix.get_indptr()[row_idx];
348                let row_end = self.matrix.get_indptr()[row_idx + 1];
349
350                for idx in row_start..row_end {
351                    let col_idx = self.matrix.get_indices()[idx];
352                    let data_val = self.matrix.get_data()[idx];
353                    result[col_idx] += data_val * x_val;
354                }
355            }
356        }
357
358        Ok(result)
359    }
360
361    fn has_adjoint(&self) -> bool {
362        true
363    }
364}
365
366impl<F: Float + SparseElement + NumAssign + Sum + 'static + Debug> AsLinearOperator<F>
367    for CsrMatrix<F>
368{
369    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>> {
370        Box::new(MatrixLinearOperator::new(self.clone()))
371    }
372}
373
374impl<F: Float + SparseElement + NumAssign + Sum + 'static + Debug> AsLinearOperator<F>
375    for crate::csr_array::CsrArray<F>
376{
377    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>> {
378        Box::new(MatrixLinearOperator::new(self.clone()))
379    }
380}
381
382// Composition operators for adding and multiplying operators
383/// Sum of two linear operators: (A + B) * x = A * x + B * x
384pub struct SumOperator<F> {
385    a: Box<dyn LinearOperator<F>>,
386    b: Box<dyn LinearOperator<F>>,
387}
388
389impl<F: Float + NumAssign> SumOperator<F> {
390    /// Create a new sum operator
391    #[allow(dead_code)]
392    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
393        if a.shape() != b.shape() {
394            return Err(crate::error::SparseError::ShapeMismatch {
395                expected: a.shape(),
396                found: b.shape(),
397            });
398        }
399        Ok(Self { a, b })
400    }
401}
402
403impl<F: Float + NumAssign> LinearOperator<F> for SumOperator<F> {
404    fn shape(&self) -> (usize, usize) {
405        self.a.shape()
406    }
407
408    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
409        let a_result = self.a.matvec(x)?;
410        let b_result = self.b.matvec(x)?;
411        Ok(a_result
412            .iter()
413            .zip(&b_result)
414            .map(|(&a, &b)| a + b)
415            .collect())
416    }
417
418    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
419        if !self.a.has_adjoint() || !self.b.has_adjoint() {
420            return Err(crate::error::SparseError::OperationNotSupported(
421                "adjoint not supported for one or both operators".to_string(),
422            ));
423        }
424        let a_result = self.a.rmatvec(x)?;
425        let b_result = self.b.rmatvec(x)?;
426        Ok(a_result
427            .iter()
428            .zip(&b_result)
429            .map(|(&a, &b)| a + b)
430            .collect())
431    }
432
433    fn has_adjoint(&self) -> bool {
434        self.a.has_adjoint() && self.b.has_adjoint()
435    }
436}
437
438/// Product of two linear operators: (A * B) * x = A * (B * x)
439pub struct ProductOperator<F> {
440    a: Box<dyn LinearOperator<F>>,
441    b: Box<dyn LinearOperator<F>>,
442}
443
444impl<F: Float + NumAssign> ProductOperator<F> {
445    /// Create a new product operator
446    #[allow(dead_code)]
447    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
448        let (_a_rows, a_cols) = a.shape();
449        let (b_rows, b_cols) = b.shape();
450        if a_cols != b_rows {
451            return Err(crate::error::SparseError::DimensionMismatch {
452                expected: a_cols,
453                found: b_rows,
454            });
455        }
456        Ok(Self { a, b })
457    }
458}
459
460impl<F: Float + NumAssign> LinearOperator<F> for ProductOperator<F> {
461    fn shape(&self) -> (usize, usize) {
462        let (a_rows, _) = self.a.shape();
463        let (_, b_cols) = self.b.shape();
464        (a_rows, b_cols)
465    }
466
467    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
468        let b_result = self.b.matvec(x)?;
469        self.a.matvec(&b_result)
470    }
471
472    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
473        if !self.a.has_adjoint() || !self.b.has_adjoint() {
474            return Err(crate::error::SparseError::OperationNotSupported(
475                "adjoint not supported for one or both operators".to_string(),
476            ));
477        }
478        // (A * B)^H = B^H * A^H
479        let a_result = self.a.rmatvec(x)?;
480        self.b.rmatvec(&a_result)
481    }
482
483    fn has_adjoint(&self) -> bool {
484        self.a.has_adjoint() && self.b.has_adjoint()
485    }
486}
487
488/// Function-based linear operator for matrix-free implementations
489pub struct FunctionOperator<F> {
490    shape: (usize, usize),
491    matvec_fn: MatVecFn<F>,
492    rmatvec_fn: Option<MatVecFn<F>>,
493}
494
495impl<F: Float + 'static> FunctionOperator<F> {
496    /// Create a new function-based operator
497    #[allow(dead_code)]
498    pub fn new<MV, RMV>(shape: (usize, usize), matvec_fn: MV, rmatvec_fn: Option<RMV>) -> Self
499    where
500        MV: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
501        RMV: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
502    {
503        Self {
504            shape,
505            matvec_fn: Box::new(matvec_fn),
506            rmatvec_fn: rmatvec_fn
507                .map(|f| Box::new(f) as Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>),
508        }
509    }
510
511    /// Create a matrix-free operator from a function
512    #[allow(dead_code)]
513    pub fn from_function<FMv>(shape: (usize, usize), matvec_fn: FMv) -> Self
514    where
515        FMv: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
516    {
517        Self::new(shape, matvec_fn, None::<fn(&[F]) -> SparseResult<Vec<F>>>)
518    }
519}
520
521impl<F: Float> LinearOperator<F> for FunctionOperator<F> {
522    fn shape(&self) -> (usize, usize) {
523        self.shape
524    }
525
526    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
527        (self.matvec_fn)(x)
528    }
529
530    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
531        match &self.rmatvec_fn {
532            Some(f) => f(x),
533            None => Err(SparseError::OperationNotSupported(
534                "adjoint not implemented for this function operator".to_string(),
535            )),
536        }
537    }
538
539    fn has_adjoint(&self) -> bool {
540        self.rmatvec_fn.is_some()
541    }
542}
543
544/// Inverse operator: A^(-1)
545/// Note: This is a conceptual operator, actual implementation depends on the specific matrix
546pub struct InverseOperator<F> {
547    original: Box<dyn LinearOperator<F>>,
548    solver_fn: SolverFn<F>,
549}
550
551impl<F: Float> InverseOperator<F> {
552    /// Create a new inverse operator with a custom solver function
553    #[allow(dead_code)]
554    pub fn new<S>(original: Box<dyn LinearOperator<F>>, solver_fn: S) -> SparseResult<Self>
555    where
556        S: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
557    {
558        let (rows, cols) = original.shape();
559        if rows != cols {
560            return Err(SparseError::ValueError(
561                "Cannot invert non-square operator".to_string(),
562            ));
563        }
564
565        Ok(Self {
566            original,
567            solver_fn: Box::new(solver_fn),
568        })
569    }
570}
571
572impl<F: Float> LinearOperator<F> for InverseOperator<F> {
573    fn shape(&self) -> (usize, usize) {
574        self.original.shape()
575    }
576
577    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
578        // A^(-1) * x is equivalent to solving A * y = x for y
579        (self.solver_fn)(x)
580    }
581
582    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
583        // (A^(-1))^H = (A^H)^(-1)
584        // So we need to solve A^H * y = x for y
585        if !self.original.has_adjoint() {
586            return Err(SparseError::OperationNotSupported(
587                "adjoint not supported for original operator".to_string(),
588            ));
589        }
590
591        // This is a conceptual implementation - in practice, you'd need
592        // a solver for the adjoint system
593        Err(SparseError::OperationNotSupported(
594            "adjoint of inverse operator not yet implemented".to_string(),
595        ))
596    }
597
598    fn has_adjoint(&self) -> bool {
599        false // Simplified for now
600    }
601}
602
603/// Transpose operator: A^T
604pub struct TransposeOperator<F> {
605    original: Box<dyn LinearOperator<F>>,
606}
607
608impl<F: Float + NumAssign> TransposeOperator<F> {
609    /// Create a new transpose operator
610    pub fn new(original: Box<dyn LinearOperator<F>>) -> Self {
611        Self { original }
612    }
613}
614
615impl<F: Float + NumAssign> LinearOperator<F> for TransposeOperator<F> {
616    fn shape(&self) -> (usize, usize) {
617        let (rows, cols) = self.original.shape();
618        (cols, rows) // Transpose dimensions
619    }
620
621    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
622        // A^T * x = (A^H * x) for real matrices
623        self.original.rmatvec(x)
624    }
625
626    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
627        // (A^T)^H = A for real matrices
628        self.original.matvec(x)
629    }
630
631    fn has_adjoint(&self) -> bool {
632        true // Transpose always has adjoint
633    }
634}
635
636/// Adjoint operator: A^H (Hermitian transpose)
637pub struct AdjointOperator<F> {
638    original: Box<dyn LinearOperator<F>>,
639}
640
641impl<F: Float + NumAssign> AdjointOperator<F> {
642    /// Create a new adjoint operator
643    pub fn new(original: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
644        if !original.has_adjoint() {
645            return Err(SparseError::OperationNotSupported(
646                "Original operator does not support adjoint operations".to_string(),
647            ));
648        }
649        Ok(Self { original })
650    }
651}
652
653impl<F: Float + NumAssign> LinearOperator<F> for AdjointOperator<F> {
654    fn shape(&self) -> (usize, usize) {
655        let (rows, cols) = self.original.shape();
656        (cols, rows) // Adjoint transposes dimensions
657    }
658
659    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
660        self.original.rmatvec(x)
661    }
662
663    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
664        self.original.matvec(x)
665    }
666
667    fn has_adjoint(&self) -> bool {
668        true
669    }
670}
671
672/// Difference operator: A - B
673pub struct DifferenceOperator<F> {
674    a: Box<dyn LinearOperator<F>>,
675    b: Box<dyn LinearOperator<F>>,
676}
677
678impl<F: Float + NumAssign> DifferenceOperator<F> {
679    /// Create a new difference operator
680    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
681        if a.shape() != b.shape() {
682            return Err(SparseError::ShapeMismatch {
683                expected: a.shape(),
684                found: b.shape(),
685            });
686        }
687        Ok(Self { a, b })
688    }
689}
690
691impl<F: Float + NumAssign> LinearOperator<F> for DifferenceOperator<F> {
692    fn shape(&self) -> (usize, usize) {
693        self.a.shape()
694    }
695
696    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
697        let a_result = self.a.matvec(x)?;
698        let b_result = self.b.matvec(x)?;
699        Ok(a_result
700            .iter()
701            .zip(&b_result)
702            .map(|(&a, &b)| a - b)
703            .collect())
704    }
705
706    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
707        if !self.a.has_adjoint() || !self.b.has_adjoint() {
708            return Err(SparseError::OperationNotSupported(
709                "adjoint not supported for one or both operators".to_string(),
710            ));
711        }
712        let a_result = self.a.rmatvec(x)?;
713        let b_result = self.b.rmatvec(x)?;
714        Ok(a_result
715            .iter()
716            .zip(&b_result)
717            .map(|(&a, &b)| a - b)
718            .collect())
719    }
720
721    fn has_adjoint(&self) -> bool {
722        self.a.has_adjoint() && self.b.has_adjoint()
723    }
724}
725
726/// Scaled operator: alpha * A
727pub struct ScaledOperator<F> {
728    alpha: F,
729    operator: Box<dyn LinearOperator<F>>,
730}
731
732impl<F: Float + NumAssign> ScaledOperator<F> {
733    /// Create a new scaled operator
734    pub fn new(alpha: F, operator: Box<dyn LinearOperator<F>>) -> Self {
735        Self { alpha, operator }
736    }
737}
738
739impl<F: Float + NumAssign> LinearOperator<F> for ScaledOperator<F> {
740    fn shape(&self) -> (usize, usize) {
741        self.operator.shape()
742    }
743
744    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
745        let result = self.operator.matvec(x)?;
746        Ok(result.iter().map(|&val| self.alpha * val).collect())
747    }
748
749    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
750        if !self.operator.has_adjoint() {
751            return Err(SparseError::OperationNotSupported(
752                "adjoint not supported for underlying operator".to_string(),
753            ));
754        }
755        let result = self.operator.rmatvec(x)?;
756        Ok(result.iter().map(|&val| self.alpha * val).collect())
757    }
758
759    fn has_adjoint(&self) -> bool {
760        self.operator.has_adjoint()
761    }
762}
763
764/// Chain/composition of multiple operators: A_n * A_(n-1) * ... * A_1
765pub struct ChainOperator<F> {
766    operators: Vec<Box<dyn LinearOperator<F>>>,
767    totalshape: (usize, usize),
768}
769
770impl<F: Float + NumAssign> ChainOperator<F> {
771    /// Create a new chain operator from a list of operators
772    /// Operators are applied from right to left (like function composition)
773    #[allow(dead_code)]
774    pub fn new(operators: Vec<Box<dyn LinearOperator<F>>>) -> SparseResult<Self> {
775        if operators.is_empty() {
776            return Err(SparseError::ValueError(
777                "Cannot create chain with no operators".to_string(),
778            ));
779        }
780
781        // Check dimension compatibility
782        #[allow(clippy::needless_range_loop)]
783        for i in 0..operators.len() - 1 {
784            let (_, a_cols) = operators[i].shape();
785            let (b_rows, _) = operators[i + 1].shape();
786            if a_cols != b_rows {
787                return Err(SparseError::DimensionMismatch {
788                    expected: a_cols,
789                    found: b_rows,
790                });
791            }
792        }
793
794        let (first_rows, _) = operators[0].shape();
795        let (_, last_cols) = operators.last().unwrap().shape();
796        let totalshape = (first_rows, last_cols);
797
798        Ok(Self {
799            operators,
800            totalshape,
801        })
802    }
803}
804
805impl<F: Float + NumAssign> LinearOperator<F> for ChainOperator<F> {
806    fn shape(&self) -> (usize, usize) {
807        self.totalshape
808    }
809
810    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
811        let mut result = x.to_vec();
812        // Apply operators from right to left
813        for op in self.operators.iter().rev() {
814            result = op.matvec(&result)?;
815        }
816        Ok(result)
817    }
818
819    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
820        // Check if all operators support adjoint
821        for op in &self.operators {
822            if !op.has_adjoint() {
823                return Err(SparseError::OperationNotSupported(
824                    "adjoint not supported for all operators in chain".to_string(),
825                ));
826            }
827        }
828
829        let mut result = x.to_vec();
830        // Apply adjoints from left to right (reverse order)
831        for op in &self.operators {
832            result = op.rmatvec(&result)?;
833        }
834        Ok(result)
835    }
836
837    fn has_adjoint(&self) -> bool {
838        self.operators.iter().all(|op| op.has_adjoint())
839    }
840}
841
842/// Power operator: A^n (for positive integer n)
843pub struct PowerOperator<F> {
844    operator: Box<dyn LinearOperator<F>>,
845    power: usize,
846}
847
848impl<F: Float + NumAssign> PowerOperator<F> {
849    /// Create a new power operator
850    pub fn new(operator: Box<dyn LinearOperator<F>>, power: usize) -> SparseResult<Self> {
851        let (rows, cols) = operator.shape();
852        if rows != cols {
853            return Err(SparseError::ValueError(
854                "Can only compute powers of square operators".to_string(),
855            ));
856        }
857        if power == 0 {
858            return Err(SparseError::ValueError(
859                "Power must be positive".to_string(),
860            ));
861        }
862        Ok(Self { operator, power })
863    }
864}
865
866impl<F: Float + NumAssign> LinearOperator<F> for PowerOperator<F> {
867    fn shape(&self) -> (usize, usize) {
868        self.operator.shape()
869    }
870
871    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
872        let mut result = x.to_vec();
873        for _ in 0..self.power {
874            result = self.operator.matvec(&result)?;
875        }
876        Ok(result)
877    }
878
879    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
880        if !self.operator.has_adjoint() {
881            return Err(SparseError::OperationNotSupported(
882                "adjoint not supported for underlying operator".to_string(),
883            ));
884        }
885        let mut result = x.to_vec();
886        for _ in 0..self.power {
887            result = self.operator.rmatvec(&result)?;
888        }
889        Ok(result)
890    }
891
892    fn has_adjoint(&self) -> bool {
893        self.operator.has_adjoint()
894    }
895}
896
897/// Enhanced LinearOperator trait with composition methods
898#[allow(dead_code)]
899pub trait LinearOperatorExt<F: Float + NumAssign>: LinearOperator<F> {
900    /// Add this operator with another
901    fn add(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;
902
903    /// Subtract another operator from this one
904    fn sub(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;
905
906    /// Multiply this operator with another (composition)
907    fn mul(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;
908
909    /// Scale this operator by a scalar
910    fn scale(&self, alpha: F) -> Box<dyn LinearOperator<F>>;
911
912    /// Transpose this operator
913    fn transpose(&self) -> Box<dyn LinearOperator<F>>;
914
915    /// Adjoint of this operator
916    fn adjoint(&self) -> SparseResult<Box<dyn LinearOperator<F>>>;
917
918    /// Power of this operator
919    fn pow(&self, n: usize) -> SparseResult<Box<dyn LinearOperator<F>>>;
920}
921
922// Specific implementations for each cloneable operator type
923macro_rules! impl_linear_operator_ext {
924    ($typ:ty) => {
925        impl<F: Float + NumAssign + Copy + 'static> LinearOperatorExt<F> for $typ {
926            fn add(
927                &self,
928                other: Box<dyn LinearOperator<F>>,
929            ) -> SparseResult<Box<dyn LinearOperator<F>>> {
930                let self_box = Box::new(self.clone());
931                Ok(Box::new(SumOperator::new(self_box, other)?))
932            }
933
934            fn sub(
935                &self,
936                other: Box<dyn LinearOperator<F>>,
937            ) -> SparseResult<Box<dyn LinearOperator<F>>> {
938                let self_box = Box::new(self.clone());
939                Ok(Box::new(DifferenceOperator::new(self_box, other)?))
940            }
941
942            fn mul(
943                &self,
944                other: Box<dyn LinearOperator<F>>,
945            ) -> SparseResult<Box<dyn LinearOperator<F>>> {
946                let self_box = Box::new(self.clone());
947                Ok(Box::new(ProductOperator::new(self_box, other)?))
948            }
949
950            fn scale(&self, alpha: F) -> Box<dyn LinearOperator<F>> {
951                let self_box = Box::new(self.clone());
952                Box::new(ScaledOperator::new(alpha, self_box))
953            }
954
955            fn transpose(&self) -> Box<dyn LinearOperator<F>> {
956                let self_box = Box::new(self.clone());
957                Box::new(TransposeOperator::new(self_box))
958            }
959
960            fn adjoint(&self) -> SparseResult<Box<dyn LinearOperator<F>>> {
961                let self_box = Box::new(self.clone());
962                Ok(Box::new(AdjointOperator::new(self_box)?))
963            }
964
965            fn pow(&self, n: usize) -> SparseResult<Box<dyn LinearOperator<F>>> {
966                let self_box = Box::new(self.clone());
967                Ok(Box::new(PowerOperator::new(self_box, n)?))
968            }
969        }
970    };
971}
972
973// Apply the macro to all cloneable operator types
974impl_linear_operator_ext!(IdentityOperator<F>);
975impl_linear_operator_ext!(ScaledIdentityOperator<F>);
976impl_linear_operator_ext!(DiagonalOperator<F>);
977
978// ZeroOperator needs special implementation with SparseElement bound
979impl<F: Float + NumAssign + Copy + SparseElement + 'static> LinearOperatorExt<F>
980    for ZeroOperator<F>
981{
982    fn add(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>> {
983        let self_box = Box::new(self.clone());
984        Ok(Box::new(SumOperator::new(self_box, other)?))
985    }
986
987    fn sub(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>> {
988        let self_box = Box::new(self.clone());
989        Ok(Box::new(DifferenceOperator::new(self_box, other)?))
990    }
991
992    fn mul(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>> {
993        let self_box = Box::new(self.clone());
994        Ok(Box::new(ProductOperator::new(self_box, other)?))
995    }
996
997    fn scale(&self, alpha: F) -> Box<dyn LinearOperator<F>> {
998        let self_box = Box::new(self.clone());
999        Box::new(ScaledOperator::new(alpha, self_box))
1000    }
1001
1002    fn transpose(&self) -> Box<dyn LinearOperator<F>> {
1003        let self_box = Box::new(self.clone());
1004        Box::new(TransposeOperator::new(self_box))
1005    }
1006
1007    fn adjoint(&self) -> SparseResult<Box<dyn LinearOperator<F>>> {
1008        let self_box = Box::new(self.clone());
1009        Ok(Box::new(AdjointOperator::new(self_box)?))
1010    }
1011
1012    fn pow(&self, n: usize) -> SparseResult<Box<dyn LinearOperator<F>>> {
1013        let self_box = Box::new(self.clone());
1014        Ok(Box::new(PowerOperator::new(self_box, n)?))
1015    }
1016}
1017
1018/// Utility functions for operator composition
1019/// Add two operators: left + right
1020#[allow(dead_code)]
1021pub fn add_operators<F: Float + NumAssign + 'static>(
1022    left: Box<dyn LinearOperator<F>>,
1023    right: Box<dyn LinearOperator<F>>,
1024) -> SparseResult<Box<dyn LinearOperator<F>>> {
1025    Ok(Box::new(SumOperator::new(left, right)?))
1026}
1027
1028/// Subtract two operators: left - right
1029#[allow(dead_code)]
1030pub fn subtract_operators<F: Float + NumAssign + 'static>(
1031    left: Box<dyn LinearOperator<F>>,
1032    right: Box<dyn LinearOperator<F>>,
1033) -> SparseResult<Box<dyn LinearOperator<F>>> {
1034    Ok(Box::new(DifferenceOperator::new(left, right)?))
1035}
1036
1037/// Multiply two operators: left * right  
1038#[allow(dead_code)]
1039pub fn multiply_operators<F: Float + NumAssign + 'static>(
1040    left: Box<dyn LinearOperator<F>>,
1041    right: Box<dyn LinearOperator<F>>,
1042) -> SparseResult<Box<dyn LinearOperator<F>>> {
1043    Ok(Box::new(ProductOperator::new(left, right)?))
1044}
1045
1046/// Scale an operator: alpha * operator
1047#[allow(dead_code)]
1048pub fn scale_operator<F: Float + NumAssign + 'static>(
1049    alpha: F,
1050    operator: Box<dyn LinearOperator<F>>,
1051) -> Box<dyn LinearOperator<F>> {
1052    Box::new(ScaledOperator::new(alpha, operator))
1053}
1054
1055/// Transpose an operator: A^T
1056#[allow(dead_code)]
1057pub fn transpose_operator<F: Float + NumAssign + 'static>(
1058    operator: Box<dyn LinearOperator<F>>,
1059) -> Box<dyn LinearOperator<F>> {
1060    Box::new(TransposeOperator::new(operator))
1061}
1062
1063/// Adjoint of an operator: A^H
1064#[allow(dead_code)]
1065pub fn adjoint_operator<F: Float + NumAssign + 'static>(
1066    operator: Box<dyn LinearOperator<F>>,
1067) -> SparseResult<Box<dyn LinearOperator<F>>> {
1068    Ok(Box::new(AdjointOperator::new(operator)?))
1069}
1070
1071/// Compose multiple operators: A_n * A_(n-1) * ... * A_1
1072#[allow(dead_code)]
1073pub fn compose_operators<F: Float + NumAssign + 'static>(
1074    operators: Vec<Box<dyn LinearOperator<F>>>,
1075) -> SparseResult<Box<dyn LinearOperator<F>>> {
1076    Ok(Box::new(ChainOperator::new(operators)?))
1077}
1078
1079/// Power of an operator: A^n
1080#[allow(dead_code)]
1081pub fn power_operator<F: Float + NumAssign + 'static>(
1082    operator: Box<dyn LinearOperator<F>>,
1083    n: usize,
1084) -> SparseResult<Box<dyn LinearOperator<F>>> {
1085    Ok(Box::new(PowerOperator::new(operator, n)?))
1086}
1087
#[cfg(test)]
mod tests {
    // Unit tests for the concrete operators, their composition wrappers, and
    // the free composition utility functions (including error propagation).
    use super::*;

    #[test]
    fn test_identity_operator() {
        let op = IdentityOperator::<f64>::new(3);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(x, y);
    }

    #[test]
    fn test_scaled_identity_operator() {
        let op = ScaledIdentityOperator::new(3, 2.0);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_diagonal_operator() {
        let op = DiagonalOperator::new(vec![2.0, 3.0, 4.0]);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_zero_operator() {
        let op = ZeroOperator::<f64>::new(3, 3);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_sum_operator() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let sum = SumOperator::new(id, scaled).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = sum.matvec(&x).unwrap();
        assert_eq!(y, vec![3.0, 6.0, 9.0]); // (I + 2I) * x = 3x
    }

    #[test]
    fn test_product_operator() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let product = ProductOperator::new(scaled, id).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = product.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 4.0, 6.0]); // (2I * I) * x = 2x
    }

    #[test]
    fn test_difference_operator() {
        let scaled_3 = Box::new(ScaledIdentityOperator::new(3, 3.0));
        let scaled_2 = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let diff = DifferenceOperator::new(scaled_3, scaled_2).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = diff.matvec(&x).unwrap();
        assert_eq!(y, vec![1.0, 2.0, 3.0]); // (3I - 2I) * x = I * x = x
    }

    #[test]
    fn test_scaled_operator() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = ScaledOperator::new(5.0, id);
        let x = vec![1.0, 2.0, 3.0];
        let y = scaled.matvec(&x).unwrap();
        assert_eq!(y, vec![5.0, 10.0, 15.0]); // 5 * I * x = 5x
    }

    #[test]
    fn test_transpose_operator() {
        let diag = Box::new(DiagonalOperator::new(vec![2.0, 3.0, 4.0]));
        let transpose = TransposeOperator::new(diag);
        let x = vec![1.0, 2.0, 3.0];
        let y = transpose.matvec(&x).unwrap();
        // For diagonal matrices, transpose equals original
        assert_eq!(y, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_adjoint_operator() {
        let diag = Box::new(DiagonalOperator::new(vec![2.0, 3.0, 4.0]));
        let adjoint = AdjointOperator::new(diag).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = adjoint.matvec(&x).unwrap();
        // For real diagonal matrices, adjoint equals original
        assert_eq!(y, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_chain_operator() {
        let op1 = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let op2 = Box::new(ScaledIdentityOperator::new(3, 3.0));
        let chain = ChainOperator::new(vec![op1, op2]).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = chain.matvec(&x).unwrap();
        // Chain applies from right to left: (2I) * (3I) * x = 6I * x = 6x
        assert_eq!(y, vec![6.0, 12.0, 18.0]);
    }

    #[test]
    fn test_power_operator() {
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let power = PowerOperator::new(scaled, 3).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = power.matvec(&x).unwrap();
        // (2I)^3 * x = 8I * x = 8x
        assert_eq!(y, vec![8.0, 16.0, 24.0]);
    }

    #[test]
    fn test_composition_utility_functions() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));

        // Test add_operators
        let sum = add_operators(id.clone(), scaled.clone()).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = sum.matvec(&x).unwrap();
        assert_eq!(y, vec![3.0, 6.0, 9.0]); // (I + 2I) * x = 3x

        // Test subtract_operators
        let diff = subtract_operators(scaled.clone(), id.clone()).unwrap();
        let y2 = diff.matvec(&x).unwrap();
        assert_eq!(y2, vec![1.0, 2.0, 3.0]); // (2I - I) * x = x

        // Test multiply_operators
        let product = multiply_operators(scaled.clone(), id.clone()).unwrap();
        let y3 = product.matvec(&x).unwrap();
        assert_eq!(y3, vec![2.0, 4.0, 6.0]); // (2I * I) * x = 2x

        // Test scale_operator
        let scaled_op = scale_operator(3.0, id.clone());
        let y4 = scaled_op.matvec(&x).unwrap();
        assert_eq!(y4, vec![3.0, 6.0, 9.0]); // 3 * I * x = 3x

        // Test transpose_operator
        let transpose = transpose_operator(scaled.clone());
        let y5 = transpose.matvec(&x).unwrap();
        assert_eq!(y5, vec![2.0, 4.0, 6.0]); // (2I)^T * x = 2I * x = 2x

        // Test adjoint_operator
        let diag = Box::new(DiagonalOperator::new(vec![2.0, 3.0, 4.0]));
        let adjoint = adjoint_operator(diag).unwrap();
        let y_adj = adjoint.matvec(&x).unwrap();
        assert_eq!(y_adj, vec![2.0, 6.0, 12.0]); // real diagonal: A^H = A

        // Test compose_operators
        let ops: Vec<Box<dyn LinearOperator<f64>>> = vec![scaled.clone(), id.clone()];
        let composed = compose_operators(ops).unwrap();
        let y6 = composed.matvec(&x).unwrap();
        assert_eq!(y6, vec![2.0, 4.0, 6.0]); // (2I * I) * x = 2x

        // Test power_operator
        let power = power_operator(scaled.clone(), 2).unwrap();
        let y7 = power.matvec(&x).unwrap();
        assert_eq!(y7, vec![4.0, 8.0, 12.0]); // (2I)^2 * x = 4I * x = 4x
    }

    #[test]
    fn test_composition_utility_error_propagation() {
        // Errors raised by the underlying constructors must surface through
        // the free utility functions rather than being swallowed.
        let id3 = Box::new(IdentityOperator::<f64>::new(3));
        let id4 = Box::new(IdentityOperator::<f64>::new(4));
        assert!(add_operators(id3.clone(), id4.clone()).is_err());
        assert!(subtract_operators(id3.clone(), id4.clone()).is_err());

        let rect1 = Box::new(ZeroOperator::<f64>::new(3, 4));
        let rect2 = Box::new(ZeroOperator::<f64>::new(5, 3));
        assert!(multiply_operators(rect1.clone(), rect2).is_err());
        assert!(power_operator(rect1, 2).is_err());
        assert!(power_operator(id3, 0).is_err());

        let func_op = Box::new(FunctionOperator::from_function((3, 3), |x: &[f64]| {
            Ok(x.to_vec())
        }));
        assert!(adjoint_operator(func_op).is_err());
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let op1 = Box::new(IdentityOperator::<f64>::new(3));
        let op2 = Box::new(IdentityOperator::<f64>::new(4));

        // Test sum operator dimension mismatch
        assert!(SumOperator::new(op1.clone(), op2.clone()).is_err());

        // Test difference operator dimension mismatch
        assert!(DifferenceOperator::new(op1.clone(), op2.clone()).is_err());

        // Test product operator dimension mismatch (incompatible dimensions)
        let rect1 = Box::new(ZeroOperator::<f64>::new(3, 4));
        let rect2 = Box::new(ZeroOperator::<f64>::new(5, 3));
        assert!(ProductOperator::new(rect1, rect2).is_err());
    }

    #[test]
    fn test_adjoint_not_supported_error() {
        // Create a function operator without adjoint support
        let func_op = Box::new(FunctionOperator::from_function((3, 3), |x: &[f64]| {
            Ok(x.to_vec())
        }));

        // Attempting to create adjoint should fail
        assert!(AdjointOperator::new(func_op).is_err());
    }

    #[test]
    fn test_power_operator_errors() {
        let rect_op = Box::new(ZeroOperator::<f64>::new(3, 4));

        // Power of non-square operator should fail
        assert!(PowerOperator::new(rect_op, 2).is_err());

        let square_op = Box::new(IdentityOperator::<f64>::new(3));

        // Power of 0 should fail
        assert!(PowerOperator::new(square_op, 0).is_err());
    }
}