1#![allow(unused_variables)]
4#![allow(unused_assignments)]
5#![allow(unused_mut)]
6
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::numeric::{Float, NumAssign};
10use std::fmt::Debug;
11use std::iter::Sum;
12use std::marker::PhantomData;
13
14type MatVecFn<F> = Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>;
16
17type SolverFn<F> = Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>;
19
20pub trait LinearOperator<F: Float> {
25 fn shape(&self) -> (usize, usize);
27
28 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>>;
30
31 fn matmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
34 let mut result = Vec::new();
35 for col in x {
36 result.push(self.matvec(col)?);
37 }
38 Ok(result)
39 }
40
41 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
44 Err(crate::error::SparseError::OperationNotSupported(
45 "adjoint not implemented for this operator".to_string(),
46 ))
47 }
48
49 fn rmatmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
52 let mut result = Vec::new();
53 for col in x {
54 result.push(self.rmatvec(col)?);
55 }
56 Ok(result)
57 }
58
59 fn has_adjoint(&self) -> bool {
61 false
62 }
63}
64
65#[derive(Clone)]
67pub struct IdentityOperator<F> {
68 size: usize,
69 phantom: PhantomData<F>,
70}
71
72impl<F> IdentityOperator<F> {
73 pub fn new(size: usize) -> Self {
75 Self {
76 size,
77 phantom: PhantomData,
78 }
79 }
80}
81
82impl<F: Float> LinearOperator<F> for IdentityOperator<F> {
83 fn shape(&self) -> (usize, usize) {
84 (self.size, self.size)
85 }
86
87 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
88 if x.len() != self.size {
89 return Err(crate::error::SparseError::DimensionMismatch {
90 expected: self.size,
91 found: x.len(),
92 });
93 }
94 Ok(x.to_vec())
95 }
96
97 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
98 self.matvec(x)
99 }
100
101 fn has_adjoint(&self) -> bool {
102 true
103 }
104}
105
106#[derive(Clone)]
108pub struct ScaledIdentityOperator<F> {
109 size: usize,
110 scale: F,
111}
112
113impl<F: Float> ScaledIdentityOperator<F> {
114 pub fn new(size: usize, scale: F) -> Self {
116 Self { size, scale }
117 }
118}
119
120impl<F: Float + NumAssign> LinearOperator<F> for ScaledIdentityOperator<F> {
121 fn shape(&self) -> (usize, usize) {
122 (self.size, self.size)
123 }
124
125 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
126 if x.len() != self.size {
127 return Err(crate::error::SparseError::DimensionMismatch {
128 expected: self.size,
129 found: x.len(),
130 });
131 }
132 Ok(x.iter().map(|&xi| xi * self.scale).collect())
133 }
134
135 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
136 self.matvec(x)
138 }
139
140 fn has_adjoint(&self) -> bool {
141 true
142 }
143}
144
145#[derive(Clone)]
147pub struct DiagonalOperator<F> {
148 diagonal: Vec<F>,
149}
150
151impl<F: Float> DiagonalOperator<F> {
152 pub fn new(diagonal: Vec<F>) -> Self {
154 Self { diagonal }
155 }
156
157 pub fn diagonal(&self) -> &[F] {
159 &self.diagonal
160 }
161}
162
163impl<F: Float + NumAssign> LinearOperator<F> for DiagonalOperator<F> {
164 fn shape(&self) -> (usize, usize) {
165 let n = self.diagonal.len();
166 (n, n)
167 }
168
169 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
170 if x.len() != self.diagonal.len() {
171 return Err(crate::error::SparseError::DimensionMismatch {
172 expected: self.diagonal.len(),
173 found: x.len(),
174 });
175 }
176 Ok(x.iter()
177 .zip(&self.diagonal)
178 .map(|(&xi, &di)| xi * di)
179 .collect())
180 }
181
182 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
183 self.matvec(x)
185 }
186
187 fn has_adjoint(&self) -> bool {
188 true
189 }
190}
191
192#[derive(Clone)]
194pub struct ZeroOperator<F> {
195 shape: (usize, usize),
196 _phantom: PhantomData<F>,
197}
198
199impl<F> ZeroOperator<F> {
200 #[allow(dead_code)]
202 pub fn new(rows: usize, cols: usize) -> Self {
203 Self {
204 shape: (rows, cols),
205 _phantom: PhantomData,
206 }
207 }
208}
209
210impl<F: Float> LinearOperator<F> for ZeroOperator<F> {
211 fn shape(&self) -> (usize, usize) {
212 self.shape
213 }
214
215 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
216 if x.len() != self.shape.1 {
217 return Err(crate::error::SparseError::DimensionMismatch {
218 expected: self.shape.1,
219 found: x.len(),
220 });
221 }
222 Ok(vec![F::zero(); self.shape.0])
223 }
224
225 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
226 if x.len() != self.shape.0 {
227 return Err(crate::error::SparseError::DimensionMismatch {
228 expected: self.shape.0,
229 found: x.len(),
230 });
231 }
232 Ok(vec![F::zero(); self.shape.1])
233 }
234
235 fn has_adjoint(&self) -> bool {
236 true
237 }
238}
239
240pub trait AsLinearOperator<F: Float> {
242 fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>>;
244}
245
/// Adapter that owns a matrix value of type `M` so that operator traits can be
/// implemented for it.
pub struct MatrixLinearOperator<F, M> {
    /// The wrapped matrix.
    matrix: M,
    /// Marker for the element type `F`.
    phantom: PhantomData<F>,
}

impl<F, M> MatrixLinearOperator<F, M> {
    /// Wraps `matrix`, taking ownership of it.
    pub fn new(matrix: M) -> Self {
        let phantom = PhantomData;
        Self { matrix, phantom }
    }
}
261
262use crate::csr::CsrMatrix;
264
265impl<F: Float + NumAssign + Sum + 'static + Debug> LinearOperator<F>
266 for MatrixLinearOperator<F, CsrMatrix<F>>
267{
268 fn shape(&self) -> (usize, usize) {
269 (self.matrix.rows(), self.matrix.cols())
270 }
271
272 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
273 if x.len() != self.matrix.cols() {
274 return Err(SparseError::DimensionMismatch {
275 expected: self.matrix.cols(),
276 found: x.len(),
277 });
278 }
279
280 let mut result = vec![F::zero(); self.matrix.rows()];
282 for (row, result_elem) in result.iter_mut().enumerate().take(self.matrix.rows()) {
283 let row_range = self.matrix.row_range(row);
284 let row_indices = &self.matrix.colindices()[row_range.clone()];
285 let row_data = &self.matrix.data[row_range];
286
287 let mut sum = F::zero();
288 for (col_idx, &col) in row_indices.iter().enumerate() {
289 sum += row_data[col_idx] * x[col];
290 }
291 *result_elem = sum;
292 }
293 Ok(result)
294 }
295
296 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
297 let transposed = self.matrix.transpose();
299 MatrixLinearOperator::new(transposed).matvec(x)
300 }
301
302 fn has_adjoint(&self) -> bool {
303 true
304 }
305}
306
307use crate::csr_array::CsrArray;
309
310impl<F: Float + NumAssign + Sum + 'static + Debug> LinearOperator<F>
311 for MatrixLinearOperator<F, CsrArray<F>>
312{
313 fn shape(&self) -> (usize, usize) {
314 self.matrix.shape()
315 }
316
317 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
318 if x.len() != self.matrix.shape().1 {
319 return Err(SparseError::DimensionMismatch {
320 expected: self.matrix.shape().1,
321 found: x.len(),
322 });
323 }
324
325 use scirs2_core::ndarray::Array1;
326 let x_array = Array1::from_vec(x.to_vec());
327 let result = self.matrix.dot_vector(&x_array.view())?;
328 Ok(result.to_vec())
329 }
330
331 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
332 if x.len() != self.matrix.shape().0 {
334 return Err(SparseError::DimensionMismatch {
335 expected: self.matrix.shape().0,
336 found: x.len(),
337 });
338 }
339
340 let mut result = vec![F::zero(); self.matrix.shape().1];
341
342 for (row_idx, &x_val) in x.iter().enumerate() {
344 if x_val != F::zero() {
345 let row_start = self.matrix.get_indptr()[row_idx];
347 let row_end = self.matrix.get_indptr()[row_idx + 1];
348
349 for idx in row_start..row_end {
350 let col_idx = self.matrix.get_indices()[idx];
351 let data_val = self.matrix.get_data()[idx];
352 result[col_idx] += data_val * x_val;
353 }
354 }
355 }
356
357 Ok(result)
358 }
359
360 fn has_adjoint(&self) -> bool {
361 true
362 }
363}
364
365impl<F: Float + NumAssign + Sum + 'static + Debug> AsLinearOperator<F> for CsrMatrix<F> {
366 fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>> {
367 Box::new(MatrixLinearOperator::new(self.clone()))
368 }
369}
370
371impl<F: Float + NumAssign + Sum + 'static + Debug> AsLinearOperator<F>
372 for crate::csr_array::CsrArray<F>
373{
374 fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>> {
375 Box::new(MatrixLinearOperator::new(self.clone()))
376 }
377}
378
379pub struct SumOperator<F> {
382 a: Box<dyn LinearOperator<F>>,
383 b: Box<dyn LinearOperator<F>>,
384}
385
386impl<F: Float + NumAssign> SumOperator<F> {
387 #[allow(dead_code)]
389 pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
390 if a.shape() != b.shape() {
391 return Err(crate::error::SparseError::ShapeMismatch {
392 expected: a.shape(),
393 found: b.shape(),
394 });
395 }
396 Ok(Self { a, b })
397 }
398}
399
400impl<F: Float + NumAssign> LinearOperator<F> for SumOperator<F> {
401 fn shape(&self) -> (usize, usize) {
402 self.a.shape()
403 }
404
405 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
406 let a_result = self.a.matvec(x)?;
407 let b_result = self.b.matvec(x)?;
408 Ok(a_result
409 .iter()
410 .zip(&b_result)
411 .map(|(&a, &b)| a + b)
412 .collect())
413 }
414
415 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
416 if !self.a.has_adjoint() || !self.b.has_adjoint() {
417 return Err(crate::error::SparseError::OperationNotSupported(
418 "adjoint not supported for one or both operators".to_string(),
419 ));
420 }
421 let a_result = self.a.rmatvec(x)?;
422 let b_result = self.b.rmatvec(x)?;
423 Ok(a_result
424 .iter()
425 .zip(&b_result)
426 .map(|(&a, &b)| a + b)
427 .collect())
428 }
429
430 fn has_adjoint(&self) -> bool {
431 self.a.has_adjoint() && self.b.has_adjoint()
432 }
433}
434
435pub struct ProductOperator<F> {
437 a: Box<dyn LinearOperator<F>>,
438 b: Box<dyn LinearOperator<F>>,
439}
440
441impl<F: Float + NumAssign> ProductOperator<F> {
442 #[allow(dead_code)]
444 pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
445 let (_a_rows, a_cols) = a.shape();
446 let (b_rows, b_cols) = b.shape();
447 if a_cols != b_rows {
448 return Err(crate::error::SparseError::DimensionMismatch {
449 expected: a_cols,
450 found: b_rows,
451 });
452 }
453 Ok(Self { a, b })
454 }
455}
456
457impl<F: Float + NumAssign> LinearOperator<F> for ProductOperator<F> {
458 fn shape(&self) -> (usize, usize) {
459 let (a_rows, _) = self.a.shape();
460 let (_, b_cols) = self.b.shape();
461 (a_rows, b_cols)
462 }
463
464 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
465 let b_result = self.b.matvec(x)?;
466 self.a.matvec(&b_result)
467 }
468
469 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
470 if !self.a.has_adjoint() || !self.b.has_adjoint() {
471 return Err(crate::error::SparseError::OperationNotSupported(
472 "adjoint not supported for one or both operators".to_string(),
473 ));
474 }
475 let a_result = self.a.rmatvec(x)?;
477 self.b.rmatvec(&a_result)
478 }
479
480 fn has_adjoint(&self) -> bool {
481 self.a.has_adjoint() && self.b.has_adjoint()
482 }
483}
484
485pub struct FunctionOperator<F> {
487 shape: (usize, usize),
488 matvec_fn: MatVecFn<F>,
489 rmatvec_fn: Option<MatVecFn<F>>,
490}
491
492impl<F: Float + 'static> FunctionOperator<F> {
493 #[allow(dead_code)]
495 pub fn new<MV, RMV>(shape: (usize, usize), matvec_fn: MV, rmatvec_fn: Option<RMV>) -> Self
496 where
497 MV: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
498 RMV: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
499 {
500 Self {
501 shape,
502 matvec_fn: Box::new(matvec_fn),
503 rmatvec_fn: rmatvec_fn
504 .map(|f| Box::new(f) as Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>),
505 }
506 }
507
508 #[allow(dead_code)]
510 pub fn from_function<FMv>(shape: (usize, usize), matvec_fn: FMv) -> Self
511 where
512 FMv: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
513 {
514 Self::new(shape, matvec_fn, None::<fn(&[F]) -> SparseResult<Vec<F>>>)
515 }
516}
517
518impl<F: Float> LinearOperator<F> for FunctionOperator<F> {
519 fn shape(&self) -> (usize, usize) {
520 self.shape
521 }
522
523 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
524 (self.matvec_fn)(x)
525 }
526
527 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
528 match &self.rmatvec_fn {
529 Some(f) => f(x),
530 None => Err(SparseError::OperationNotSupported(
531 "adjoint not implemented for this function operator".to_string(),
532 )),
533 }
534 }
535
536 fn has_adjoint(&self) -> bool {
537 self.rmatvec_fn.is_some()
538 }
539}
540
541pub struct InverseOperator<F> {
544 original: Box<dyn LinearOperator<F>>,
545 solver_fn: SolverFn<F>,
546}
547
548impl<F: Float> InverseOperator<F> {
549 #[allow(dead_code)]
551 pub fn new<S>(original: Box<dyn LinearOperator<F>>, solver_fn: S) -> SparseResult<Self>
552 where
553 S: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
554 {
555 let (rows, cols) = original.shape();
556 if rows != cols {
557 return Err(SparseError::ValueError(
558 "Cannot invert non-square operator".to_string(),
559 ));
560 }
561
562 Ok(Self {
563 original,
564 solver_fn: Box::new(solver_fn),
565 })
566 }
567}
568
569impl<F: Float> LinearOperator<F> for InverseOperator<F> {
570 fn shape(&self) -> (usize, usize) {
571 self.original.shape()
572 }
573
574 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
575 (self.solver_fn)(x)
577 }
578
579 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
580 if !self.original.has_adjoint() {
583 return Err(SparseError::OperationNotSupported(
584 "adjoint not supported for original operator".to_string(),
585 ));
586 }
587
588 Err(SparseError::OperationNotSupported(
591 "adjoint of inverse operator not yet implemented".to_string(),
592 ))
593 }
594
595 fn has_adjoint(&self) -> bool {
596 false }
598}
599
600pub struct TransposeOperator<F> {
602 original: Box<dyn LinearOperator<F>>,
603}
604
605impl<F: Float + NumAssign> TransposeOperator<F> {
606 pub fn new(original: Box<dyn LinearOperator<F>>) -> Self {
608 Self { original }
609 }
610}
611
612impl<F: Float + NumAssign> LinearOperator<F> for TransposeOperator<F> {
613 fn shape(&self) -> (usize, usize) {
614 let (rows, cols) = self.original.shape();
615 (cols, rows) }
617
618 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
619 self.original.rmatvec(x)
621 }
622
623 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
624 self.original.matvec(x)
626 }
627
628 fn has_adjoint(&self) -> bool {
629 true }
631}
632
633pub struct AdjointOperator<F> {
635 original: Box<dyn LinearOperator<F>>,
636}
637
638impl<F: Float + NumAssign> AdjointOperator<F> {
639 pub fn new(original: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
641 if !original.has_adjoint() {
642 return Err(SparseError::OperationNotSupported(
643 "Original operator does not support adjoint operations".to_string(),
644 ));
645 }
646 Ok(Self { original })
647 }
648}
649
650impl<F: Float + NumAssign> LinearOperator<F> for AdjointOperator<F> {
651 fn shape(&self) -> (usize, usize) {
652 let (rows, cols) = self.original.shape();
653 (cols, rows) }
655
656 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
657 self.original.rmatvec(x)
658 }
659
660 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
661 self.original.matvec(x)
662 }
663
664 fn has_adjoint(&self) -> bool {
665 true
666 }
667}
668
669pub struct DifferenceOperator<F> {
671 a: Box<dyn LinearOperator<F>>,
672 b: Box<dyn LinearOperator<F>>,
673}
674
675impl<F: Float + NumAssign> DifferenceOperator<F> {
676 pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
678 if a.shape() != b.shape() {
679 return Err(SparseError::ShapeMismatch {
680 expected: a.shape(),
681 found: b.shape(),
682 });
683 }
684 Ok(Self { a, b })
685 }
686}
687
688impl<F: Float + NumAssign> LinearOperator<F> for DifferenceOperator<F> {
689 fn shape(&self) -> (usize, usize) {
690 self.a.shape()
691 }
692
693 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
694 let a_result = self.a.matvec(x)?;
695 let b_result = self.b.matvec(x)?;
696 Ok(a_result
697 .iter()
698 .zip(&b_result)
699 .map(|(&a, &b)| a - b)
700 .collect())
701 }
702
703 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
704 if !self.a.has_adjoint() || !self.b.has_adjoint() {
705 return Err(SparseError::OperationNotSupported(
706 "adjoint not supported for one or both operators".to_string(),
707 ));
708 }
709 let a_result = self.a.rmatvec(x)?;
710 let b_result = self.b.rmatvec(x)?;
711 Ok(a_result
712 .iter()
713 .zip(&b_result)
714 .map(|(&a, &b)| a - b)
715 .collect())
716 }
717
718 fn has_adjoint(&self) -> bool {
719 self.a.has_adjoint() && self.b.has_adjoint()
720 }
721}
722
723pub struct ScaledOperator<F> {
725 alpha: F,
726 operator: Box<dyn LinearOperator<F>>,
727}
728
729impl<F: Float + NumAssign> ScaledOperator<F> {
730 pub fn new(alpha: F, operator: Box<dyn LinearOperator<F>>) -> Self {
732 Self { alpha, operator }
733 }
734}
735
736impl<F: Float + NumAssign> LinearOperator<F> for ScaledOperator<F> {
737 fn shape(&self) -> (usize, usize) {
738 self.operator.shape()
739 }
740
741 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
742 let result = self.operator.matvec(x)?;
743 Ok(result.iter().map(|&val| self.alpha * val).collect())
744 }
745
746 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
747 if !self.operator.has_adjoint() {
748 return Err(SparseError::OperationNotSupported(
749 "adjoint not supported for underlying operator".to_string(),
750 ));
751 }
752 let result = self.operator.rmatvec(x)?;
753 Ok(result.iter().map(|&val| self.alpha * val).collect())
754 }
755
756 fn has_adjoint(&self) -> bool {
757 self.operator.has_adjoint()
758 }
759}
760
761pub struct ChainOperator<F> {
763 operators: Vec<Box<dyn LinearOperator<F>>>,
764 totalshape: (usize, usize),
765}
766
767impl<F: Float + NumAssign> ChainOperator<F> {
768 #[allow(dead_code)]
771 pub fn new(operators: Vec<Box<dyn LinearOperator<F>>>) -> SparseResult<Self> {
772 if operators.is_empty() {
773 return Err(SparseError::ValueError(
774 "Cannot create chain with no operators".to_string(),
775 ));
776 }
777
778 #[allow(clippy::needless_range_loop)]
780 for i in 0..operators.len() - 1 {
781 let (_, a_cols) = operators[i].shape();
782 let (b_rows, _) = operators[i + 1].shape();
783 if a_cols != b_rows {
784 return Err(SparseError::DimensionMismatch {
785 expected: a_cols,
786 found: b_rows,
787 });
788 }
789 }
790
791 let (first_rows, _) = operators[0].shape();
792 let (_, last_cols) = operators.last().unwrap().shape();
793 let totalshape = (first_rows, last_cols);
794
795 Ok(Self {
796 operators,
797 totalshape,
798 })
799 }
800}
801
802impl<F: Float + NumAssign> LinearOperator<F> for ChainOperator<F> {
803 fn shape(&self) -> (usize, usize) {
804 self.totalshape
805 }
806
807 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
808 let mut result = x.to_vec();
809 for op in self.operators.iter().rev() {
811 result = op.matvec(&result)?;
812 }
813 Ok(result)
814 }
815
816 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
817 for op in &self.operators {
819 if !op.has_adjoint() {
820 return Err(SparseError::OperationNotSupported(
821 "adjoint not supported for all operators in chain".to_string(),
822 ));
823 }
824 }
825
826 let mut result = x.to_vec();
827 for op in &self.operators {
829 result = op.rmatvec(&result)?;
830 }
831 Ok(result)
832 }
833
834 fn has_adjoint(&self) -> bool {
835 self.operators.iter().all(|op| op.has_adjoint())
836 }
837}
838
839pub struct PowerOperator<F> {
841 operator: Box<dyn LinearOperator<F>>,
842 power: usize,
843}
844
845impl<F: Float + NumAssign> PowerOperator<F> {
846 pub fn new(operator: Box<dyn LinearOperator<F>>, power: usize) -> SparseResult<Self> {
848 let (rows, cols) = operator.shape();
849 if rows != cols {
850 return Err(SparseError::ValueError(
851 "Can only compute powers of square operators".to_string(),
852 ));
853 }
854 if power == 0 {
855 return Err(SparseError::ValueError(
856 "Power must be positive".to_string(),
857 ));
858 }
859 Ok(Self { operator, power })
860 }
861}
862
863impl<F: Float + NumAssign> LinearOperator<F> for PowerOperator<F> {
864 fn shape(&self) -> (usize, usize) {
865 self.operator.shape()
866 }
867
868 fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
869 let mut result = x.to_vec();
870 for _ in 0..self.power {
871 result = self.operator.matvec(&result)?;
872 }
873 Ok(result)
874 }
875
876 fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
877 if !self.operator.has_adjoint() {
878 return Err(SparseError::OperationNotSupported(
879 "adjoint not supported for underlying operator".to_string(),
880 ));
881 }
882 let mut result = x.to_vec();
883 for _ in 0..self.power {
884 result = self.operator.rmatvec(&result)?;
885 }
886 Ok(result)
887 }
888
889 fn has_adjoint(&self) -> bool {
890 self.operator.has_adjoint()
891 }
892}
893
894#[allow(dead_code)]
896pub trait LinearOperatorExt<F: Float + NumAssign>: LinearOperator<F> {
897 fn add(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;
899
900 fn sub(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;
902
903 fn mul(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;
905
906 fn scale(&self, alpha: F) -> Box<dyn LinearOperator<F>>;
908
909 fn transpose(&self) -> Box<dyn LinearOperator<F>>;
911
912 fn adjoint(&self) -> SparseResult<Box<dyn LinearOperator<F>>>;
914
915 fn pow(&self, n: usize) -> SparseResult<Box<dyn LinearOperator<F>>>;
917}
918
919macro_rules! impl_linear_operator_ext {
921 ($typ:ty) => {
922 impl<F: Float + NumAssign + Copy + 'static> LinearOperatorExt<F> for $typ {
923 fn add(
924 &self,
925 other: Box<dyn LinearOperator<F>>,
926 ) -> SparseResult<Box<dyn LinearOperator<F>>> {
927 let self_box = Box::new(self.clone());
928 Ok(Box::new(SumOperator::new(self_box, other)?))
929 }
930
931 fn sub(
932 &self,
933 other: Box<dyn LinearOperator<F>>,
934 ) -> SparseResult<Box<dyn LinearOperator<F>>> {
935 let self_box = Box::new(self.clone());
936 Ok(Box::new(DifferenceOperator::new(self_box, other)?))
937 }
938
939 fn mul(
940 &self,
941 other: Box<dyn LinearOperator<F>>,
942 ) -> SparseResult<Box<dyn LinearOperator<F>>> {
943 let self_box = Box::new(self.clone());
944 Ok(Box::new(ProductOperator::new(self_box, other)?))
945 }
946
947 fn scale(&self, alpha: F) -> Box<dyn LinearOperator<F>> {
948 let self_box = Box::new(self.clone());
949 Box::new(ScaledOperator::new(alpha, self_box))
950 }
951
952 fn transpose(&self) -> Box<dyn LinearOperator<F>> {
953 let self_box = Box::new(self.clone());
954 Box::new(TransposeOperator::new(self_box))
955 }
956
957 fn adjoint(&self) -> SparseResult<Box<dyn LinearOperator<F>>> {
958 let self_box = Box::new(self.clone());
959 Ok(Box::new(AdjointOperator::new(self_box)?))
960 }
961
962 fn pow(&self, n: usize) -> SparseResult<Box<dyn LinearOperator<F>>> {
963 let self_box = Box::new(self.clone());
964 Ok(Box::new(PowerOperator::new(self_box, n)?))
965 }
966 }
967 };
968}
969
970impl_linear_operator_ext!(IdentityOperator<F>);
972impl_linear_operator_ext!(ScaledIdentityOperator<F>);
973impl_linear_operator_ext!(DiagonalOperator<F>);
974impl_linear_operator_ext!(ZeroOperator<F>);
975
976#[allow(dead_code)]
979pub fn add_operators<F: Float + NumAssign + 'static>(
980 left: Box<dyn LinearOperator<F>>,
981 right: Box<dyn LinearOperator<F>>,
982) -> SparseResult<Box<dyn LinearOperator<F>>> {
983 Ok(Box::new(SumOperator::new(left, right)?))
984}
985
986#[allow(dead_code)]
988pub fn subtract_operators<F: Float + NumAssign + 'static>(
989 left: Box<dyn LinearOperator<F>>,
990 right: Box<dyn LinearOperator<F>>,
991) -> SparseResult<Box<dyn LinearOperator<F>>> {
992 Ok(Box::new(DifferenceOperator::new(left, right)?))
993}
994
995#[allow(dead_code)]
997pub fn multiply_operators<F: Float + NumAssign + 'static>(
998 left: Box<dyn LinearOperator<F>>,
999 right: Box<dyn LinearOperator<F>>,
1000) -> SparseResult<Box<dyn LinearOperator<F>>> {
1001 Ok(Box::new(ProductOperator::new(left, right)?))
1002}
1003
1004#[allow(dead_code)]
1006pub fn scale_operator<F: Float + NumAssign + 'static>(
1007 alpha: F,
1008 operator: Box<dyn LinearOperator<F>>,
1009) -> Box<dyn LinearOperator<F>> {
1010 Box::new(ScaledOperator::new(alpha, operator))
1011}
1012
1013#[allow(dead_code)]
1015pub fn transpose_operator<F: Float + NumAssign + 'static>(
1016 operator: Box<dyn LinearOperator<F>>,
1017) -> Box<dyn LinearOperator<F>> {
1018 Box::new(TransposeOperator::new(operator))
1019}
1020
1021#[allow(dead_code)]
1023pub fn adjoint_operator<F: Float + NumAssign + 'static>(
1024 operator: Box<dyn LinearOperator<F>>,
1025) -> SparseResult<Box<dyn LinearOperator<F>>> {
1026 Ok(Box::new(AdjointOperator::new(operator)?))
1027}
1028
1029#[allow(dead_code)]
1031pub fn compose_operators<F: Float + NumAssign + 'static>(
1032 operators: Vec<Box<dyn LinearOperator<F>>>,
1033) -> SparseResult<Box<dyn LinearOperator<F>>> {
1034 Ok(Box::new(ChainOperator::new(operators)?))
1035}
1036
1037#[allow(dead_code)]
1039pub fn power_operator<F: Float + NumAssign + 'static>(
1040 operator: Box<dyn LinearOperator<F>>,
1041 n: usize,
1042) -> SparseResult<Box<dyn LinearOperator<F>>> {
1043 Ok(Box::new(PowerOperator::new(operator, n)?))
1044}
1045
#[cfg(test)]
mod tests {
    use super::*;

    /// Input vector shared by the tests below.
    fn sample() -> Vec<f64> {
        vec![1.0, 2.0, 3.0]
    }

    #[test]
    fn test_identity_operator() {
        let identity = IdentityOperator::<f64>::new(3);
        let input = sample();
        let output = identity.matvec(&input).unwrap();
        assert_eq!(input, output);
    }

    #[test]
    fn test_scaled_identity_operator() {
        let doubled = ScaledIdentityOperator::new(3, 2.0);
        let output = doubled.matvec(&sample()).unwrap();
        assert_eq!(output, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_diagonal_operator() {
        let diagonal = DiagonalOperator::new(vec![2.0, 3.0, 4.0]);
        let output = diagonal.matvec(&sample()).unwrap();
        assert_eq!(output, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_zero_operator() {
        let zero = ZeroOperator::<f64>::new(3, 3);
        let output = zero.matvec(&sample()).unwrap();
        assert_eq!(output, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_sum_operator() {
        let identity = Box::new(IdentityOperator::<f64>::new(3));
        let doubled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let sum = SumOperator::new(identity, doubled).unwrap();
        // (1 + 2) * x
        let output = sum.matvec(&sample()).unwrap();
        assert_eq!(output, vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_product_operator() {
        let identity = Box::new(IdentityOperator::<f64>::new(3));
        let doubled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let product = ProductOperator::new(doubled, identity).unwrap();
        // 2 * (1 * x)
        let output = product.matvec(&sample()).unwrap();
        assert_eq!(output, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_difference_operator() {
        let tripled = Box::new(ScaledIdentityOperator::new(3, 3.0));
        let doubled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let difference = DifferenceOperator::new(tripled, doubled).unwrap();
        // (3 - 2) * x
        let output = difference.matvec(&sample()).unwrap();
        assert_eq!(output, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_scaled_operator() {
        let identity = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = ScaledOperator::new(5.0, identity);
        let output = scaled.matvec(&sample()).unwrap();
        assert_eq!(output, vec![5.0, 10.0, 15.0]);
    }

    #[test]
    fn test_transpose_operator() {
        let diagonal = Box::new(DiagonalOperator::new(vec![2.0, 3.0, 4.0]));
        let transposed = TransposeOperator::new(diagonal);
        let output = transposed.matvec(&sample()).unwrap();
        assert_eq!(output, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_adjoint_operator() {
        let diagonal = Box::new(DiagonalOperator::new(vec![2.0, 3.0, 4.0]));
        let adjoint = AdjointOperator::new(diagonal).unwrap();
        let output = adjoint.matvec(&sample()).unwrap();
        assert_eq!(output, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_chain_operator() {
        let doubled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let tripled = Box::new(ScaledIdentityOperator::new(3, 3.0));
        let chain = ChainOperator::new(vec![doubled, tripled]).unwrap();
        // 2 * (3 * x)
        let output = chain.matvec(&sample()).unwrap();
        assert_eq!(output, vec![6.0, 12.0, 18.0]);
    }

    #[test]
    fn test_power_operator() {
        let doubled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let cubed = PowerOperator::new(doubled, 3).unwrap();
        // 2^3 * x
        let output = cubed.matvec(&sample()).unwrap();
        assert_eq!(output, vec![8.0, 16.0, 24.0]);
    }

    #[test]
    fn test_composition_utility_functions() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let x = sample();

        // Addition: (1 + 2) * x
        let sum = add_operators(id.clone(), scaled.clone()).unwrap();
        assert_eq!(sum.matvec(&x).unwrap(), vec![3.0, 6.0, 9.0]);

        // Subtraction: (2 - 1) * x
        let diff = subtract_operators(scaled.clone(), id.clone()).unwrap();
        assert_eq!(diff.matvec(&x).unwrap(), vec![1.0, 2.0, 3.0]);

        // Composition: 2 * (1 * x)
        let product = multiply_operators(scaled.clone(), id.clone()).unwrap();
        assert_eq!(product.matvec(&x).unwrap(), vec![2.0, 4.0, 6.0]);

        // Scalar multiple: 3 * x
        let scaled_op = scale_operator(3.0, id.clone());
        assert_eq!(scaled_op.matvec(&x).unwrap(), vec![3.0, 6.0, 9.0]);

        // Transpose of a scaled identity behaves like the operator itself.
        let transpose = transpose_operator(scaled.clone());
        assert_eq!(transpose.matvec(&x).unwrap(), vec![2.0, 4.0, 6.0]);

        // Chain of two operators.
        let ops: Vec<Box<dyn LinearOperator<f64>>> = vec![scaled.clone(), id.clone()];
        let composed = compose_operators(ops).unwrap();
        assert_eq!(composed.matvec(&x).unwrap(), vec![2.0, 4.0, 6.0]);

        // Power: 2^2 * x
        let power = power_operator(scaled.clone(), 2).unwrap();
        assert_eq!(power.matvec(&x).unwrap(), vec![4.0, 8.0, 12.0]);
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let three = Box::new(IdentityOperator::<f64>::new(3));
        let four = Box::new(IdentityOperator::<f64>::new(4));

        // Sum and difference require identical shapes.
        assert!(SumOperator::new(three.clone(), four.clone()).is_err());
        assert!(DifferenceOperator::new(three.clone(), four.clone()).is_err());

        // Product requires cols(left) == rows(right): 4 != 5.
        let left = Box::new(ZeroOperator::<f64>::new(3, 4));
        let right = Box::new(ZeroOperator::<f64>::new(5, 3));
        assert!(ProductOperator::new(left, right).is_err());
    }

    #[test]
    fn test_adjoint_not_supported_error() {
        // A function operator built without an adjoint closure.
        let forward_only = Box::new(FunctionOperator::from_function((3, 3), |x: &[f64]| {
            Ok(x.to_vec())
        }));

        assert!(AdjointOperator::new(forward_only).is_err());
    }

    #[test]
    fn test_power_operator_errors() {
        // Non-square operators cannot be raised to a power.
        let rectangular = Box::new(ZeroOperator::<f64>::new(3, 4));
        assert!(PowerOperator::new(rectangular, 2).is_err());

        // A power of zero is rejected.
        let square = Box::new(IdentityOperator::<f64>::new(3));
        assert!(PowerOperator::new(square, 0).is_err());
    }
}