use crate::error::{SparseError, SparseResult};
use num_traits::{Float, NumAssign};
use std::fmt::Debug;
use std::iter::Sum;
use std::marker::PhantomData;

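/// Abstract interface for a linear operator `A`, defined by its action on
/// vectors (`matvec`, and optionally the adjoint `rmatvec`) rather than by an
/// explicit matrix.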
pub trait LinearOperator<F: Float> {
    /// Returns the shape of the operator as `(rows, cols)`.
    fn shape(&self) -> (usize, usize);

    /// Applies the operator to a vector: computes `A * x`.
    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>>;

    /// Applies the operator to each column of a matrix, given as a slice of columns.
    fn matmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
        let mut result = Vec::new();
        for col in x {
            result.push(self.matvec(col)?);
        }
        Ok(result)
    }

    /// Applies the adjoint (transpose) of the operator: computes `A^T * x`.
    /// The default implementation reports that no adjoint is available.
    fn rmatvec(&self, _x: &[F]) -> SparseResult<Vec<F>> {
        Err(SparseError::OperationNotSupported(
            "adjoint not implemented for this operator".to_string(),
        ))
    }

    /// Applies the adjoint to each column of a matrix, given as a slice of columns.
    fn rmatmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
        let mut result = Vec::new();
        for col in x {
            result.push(self.rmatvec(col)?);
        }
        Ok(result)
    }

    /// Returns `true` if the operator provides an adjoint (`rmatvec`).
    fn has_adjoint(&self) -> bool {
        false
    }
}

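/// The identity operator `I`, which maps every vector to itself.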
pub struct IdentityOperator<F> {
    size: usize,
    _phantom: PhantomData<F>,
}

impl<F> IdentityOperator<F> {
    pub fn new(size: usize) -> Self {
        Self {
            size,
            _phantom: PhantomData,
        }
    }
}

impl<F: Float> LinearOperator<F> for IdentityOperator<F> {
    fn shape(&self) -> (usize, usize) {
        (self.size, self.size)
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.size {
            return Err(SparseError::DimensionMismatch {
                expected: self.size,
                found: x.len(),
            });
        }
        Ok(x.to_vec())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        self.matvec(x)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

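/// A scaled identity operator `alpha * I`.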
pub struct ScaledIdentityOperator<F> {
    size: usize,
    scale: F,
}

impl<F: Float> ScaledIdentityOperator<F> {
    pub fn new(size: usize, scale: F) -> Self {
        Self { size, scale }
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for ScaledIdentityOperator<F> {
    fn shape(&self) -> (usize, usize) {
        (self.size, self.size)
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.size {
            return Err(SparseError::DimensionMismatch {
                expected: self.size,
                found: x.len(),
            });
        }
        Ok(x.iter().map(|&xi| xi * self.scale).collect())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        self.matvec(x)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

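/// A diagonal operator `D`, which multiplies each component of a vector by the
/// corresponding diagonal entry.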
pub struct DiagonalOperator<F> {
    diagonal: Vec<F>,
}

impl<F: Float> DiagonalOperator<F> {
    pub fn new(diagonal: Vec<F>) -> Self {
        Self { diagonal }
    }

    pub fn diagonal(&self) -> &[F] {
        &self.diagonal
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for DiagonalOperator<F> {
    fn shape(&self) -> (usize, usize) {
        let n = self.diagonal.len();
        (n, n)
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.diagonal.len() {
            return Err(SparseError::DimensionMismatch {
                expected: self.diagonal.len(),
                found: x.len(),
            });
        }
        Ok(x.iter()
            .zip(&self.diagonal)
            .map(|(&xi, &di)| xi * di)
            .collect())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        self.matvec(x)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

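/// The zero operator, which maps every vector to the zero vector of the
/// appropriate length.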
pub struct ZeroOperator<F> {
    shape: (usize, usize),
    _phantom: PhantomData<F>,
}

impl<F> ZeroOperator<F> {
    #[allow(dead_code)]
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            shape: (rows, cols),
            _phantom: PhantomData,
        }
    }
}

impl<F: Float> LinearOperator<F> for ZeroOperator<F> {
    fn shape(&self) -> (usize, usize) {
        self.shape
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.shape.1 {
            return Err(SparseError::DimensionMismatch {
                expected: self.shape.1,
                found: x.len(),
            });
        }
        Ok(vec![F::zero(); self.shape.0])
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.shape.0 {
            return Err(SparseError::DimensionMismatch {
                expected: self.shape.0,
                found: x.len(),
            });
        }
        Ok(vec![F::zero(); self.shape.1])
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

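/// Conversion trait for types that can be viewed as a [`LinearOperator`].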
pub trait AsLinearOperator<F: Float> {
    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>>;
}

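/// Wraps a concrete matrix type so that it can be used as a [`LinearOperator`].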
pub struct MatrixLinearOperator<F, M> {
    matrix: M,
    _phantom: PhantomData<F>,
}

impl<F, M> MatrixLinearOperator<F, M> {
    pub fn new(matrix: M) -> Self {
        Self {
            matrix,
            _phantom: PhantomData,
        }
    }
}

use crate::csr::CsrMatrix;

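// LinearOperator implementation backed by a CSR matrix.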
impl<F: Float + NumAssign + Sum + 'static + Debug> LinearOperator<F>
    for MatrixLinearOperator<F, CsrMatrix<F>>
{
    fn shape(&self) -> (usize, usize) {
        (self.matrix.rows(), self.matrix.cols())
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.matrix.cols() {
            return Err(SparseError::DimensionMismatch {
                expected: self.matrix.cols(),
                found: x.len(),
            });
        }

        let mut result = vec![F::zero(); self.matrix.rows()];
        // Accumulate the dot product of each sparse row with x.
        for (row, result_elem) in result.iter_mut().enumerate() {
            let row_range = self.matrix.row_range(row);
            let row_indices = &self.matrix.col_indices()[row_range.clone()];
            let row_data = &self.matrix.data[row_range];

            let mut sum = F::zero();
            for (col_idx, &col) in row_indices.iter().enumerate() {
                sum += row_data[col_idx] * x[col];
            }
            *result_elem = sum;
        }
        Ok(result)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        // The adjoint is applied by materializing the transpose and multiplying.
        let transposed = self.matrix.transpose();
        MatrixLinearOperator::new(transposed).matvec(x)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

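// Allow a CsrMatrix to be used directly wherever a boxed LinearOperator is expected.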
impl<F: Float + NumAssign + Sum + 'static + Debug> AsLinearOperator<F> for CsrMatrix<F> {
    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>> {
        Box::new(MatrixLinearOperator::new(self.clone()))
    }
}

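/// The sum of two operators, `A + B`, applied without forming an explicit matrix.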
pub struct SumOperator<F> {
    a: Box<dyn LinearOperator<F>>,
    b: Box<dyn LinearOperator<F>>,
}

impl<F: Float + NumAssign> SumOperator<F> {
    #[allow(dead_code)]
    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
        if a.shape() != b.shape() {
            return Err(SparseError::ShapeMismatch {
                expected: a.shape(),
                found: b.shape(),
            });
        }
        Ok(Self { a, b })
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for SumOperator<F> {
    fn shape(&self) -> (usize, usize) {
        self.a.shape()
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let a_result = self.a.matvec(x)?;
        let b_result = self.b.matvec(x)?;
        Ok(a_result
            .iter()
            .zip(&b_result)
            .map(|(&a, &b)| a + b)
            .collect())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if !self.a.has_adjoint() || !self.b.has_adjoint() {
            return Err(SparseError::OperationNotSupported(
                "adjoint not supported for one or both operators".to_string(),
            ));
        }
        let a_result = self.a.rmatvec(x)?;
        let b_result = self.b.rmatvec(x)?;
        Ok(a_result
            .iter()
            .zip(&b_result)
            .map(|(&a, &b)| a + b)
            .collect())
    }

    fn has_adjoint(&self) -> bool {
        self.a.has_adjoint() && self.b.has_adjoint()
    }
}

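/// The composition of two operators, `A * B`, applied by chaining `matvec` calls.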
pub struct ProductOperator<F> {
    a: Box<dyn LinearOperator<F>>,
    b: Box<dyn LinearOperator<F>>,
}

impl<F: Float + NumAssign> ProductOperator<F> {
    #[allow(dead_code)]
    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
        let (_a_rows, a_cols) = a.shape();
        let (b_rows, _b_cols) = b.shape();
        if a_cols != b_rows {
            return Err(SparseError::DimensionMismatch {
                expected: a_cols,
                found: b_rows,
            });
        }
        Ok(Self { a, b })
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for ProductOperator<F> {
    fn shape(&self) -> (usize, usize) {
        let (a_rows, _) = self.a.shape();
        let (_, b_cols) = self.b.shape();
        (a_rows, b_cols)
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let b_result = self.b.matvec(x)?;
        self.a.matvec(&b_result)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if !self.a.has_adjoint() || !self.b.has_adjoint() {
            return Err(SparseError::OperationNotSupported(
                "adjoint not supported for one or both operators".to_string(),
            ));
        }
        let a_result = self.a.rmatvec(x)?;
        self.b.rmatvec(&a_result)
    }

    fn has_adjoint(&self) -> bool {
        self.a.has_adjoint() && self.b.has_adjoint()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_identity_operator() {
        let op = IdentityOperator::<f64>::new(3);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(x, y);
    }

    #[test]
    fn test_scaled_identity_operator() {
        let op = ScaledIdentityOperator::new(3, 2.0);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_diagonal_operator() {
        let op = DiagonalOperator::new(vec![2.0, 3.0, 4.0]);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_zero_operator() {
        let op = ZeroOperator::<f64>::new(3, 3);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_sum_operator() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let sum = SumOperator::new(id, scaled).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = sum.matvec(&x).unwrap();
        assert_eq!(y, vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_product_operator() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let product = ProductOperator::new(scaled, id).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = product.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 4.0, 6.0]);
    }
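
    // Additional sketch of an adjoint-path check, exercising the rmatvec
    // implementations defined above; for a diagonal operator the adjoint
    // equals the operator itself.
    #[test]
    fn test_diagonal_operator_adjoint() {
        let op = DiagonalOperator::new(vec![2.0, 3.0, 4.0]);
        assert!(op.has_adjoint());
        let x = vec![1.0, 2.0, 3.0];
        let y = op.rmatvec(&x).unwrap();
        assert_eq!(y, op.matvec(&x).unwrap());
    }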
}