Skip to main content

tensors/
linalg.rs

//! Pure Rust linear algebra: portable scalar kernels, `pulp` SIMD kernels, blocked GEMM, and lazy linear operators.
2
3use crate::array2::Array2;
4use crate::array3::Axis3;
5use crate::error::{Error, Result};
6use crate::numeric::Float;
7use crate::rand::SmallRng;
8use crate::view2::{ArrayView2, ArrayViewMut2};
9use crate::view3::ArrayView3;
10use crate::workspace::Workspace;
11use pulp::{Arch, Simd, WithSimd};
12use rayon::prelude::*;
13
14/// Compute a dot product.
15pub fn dot<T: Float>(x: &[T], y: &[T]) -> Result<T> {
16    if x.len() != y.len() {
17        return Err(Error::shape(vec![x.len()], vec![y.len()]));
18    }
19    Ok(x.iter().zip(y).map(|(&a, &b)| a * b).sum())
20}
21
22/// Compute `y += alpha * x`.
23pub fn axpy<T: Float>(alpha: T, x: &[T], y: &mut [T]) -> Result<()> {
24    if x.len() != y.len() {
25        return Err(Error::shape(vec![x.len()], vec![y.len()]));
26    }
27    for (yi, &xi) in y.iter_mut().zip(x) {
28        *yi += alpha * xi;
29    }
30    Ok(())
31}
32
33/// Euclidean norm.
34pub fn norm_l2<T: Float>(x: &[T]) -> T {
35    x.iter()
36        .copied()
37        .map(|value| value * value)
38        .sum::<T>()
39        .sqrt()
40}
41
42/// Compute a dot product using explicit SIMD for contiguous `f32` slices.
43pub fn dot_f32(x: &[f32], y: &[f32]) -> Result<f32> {
44    if x.len() != y.len() {
45        return Err(Error::shape(vec![x.len()], vec![y.len()]));
46    }
47    Ok(Arch::new().dispatch(DotF32 { x, y }))
48}
49
50/// Compute a dot product using explicit SIMD for contiguous `f64` slices.
51pub fn dot_f64(x: &[f64], y: &[f64]) -> Result<f64> {
52    if x.len() != y.len() {
53        return Err(Error::shape(vec![x.len()], vec![y.len()]));
54    }
55    Ok(Arch::new().dispatch(DotF64 { x, y }))
56}
57
58/// Compute `y += alpha * x` using explicit SIMD for contiguous `f32` slices.
59pub fn axpy_f32(alpha: f32, x: &[f32], y: &mut [f32]) -> Result<()> {
60    if x.len() != y.len() {
61        return Err(Error::shape(vec![x.len()], vec![y.len()]));
62    }
63    Arch::new().dispatch(AxpyF32 { alpha, x, y });
64    Ok(())
65}
66
67/// Compute `y += alpha * x` using explicit SIMD for contiguous `f64` slices.
68pub fn axpy_f64(alpha: f64, x: &[f64], y: &mut [f64]) -> Result<()> {
69    if x.len() != y.len() {
70        return Err(Error::shape(vec![x.len()], vec![y.len()]));
71    }
72    Arch::new().dispatch(AxpyF64 { alpha, x, y });
73    Ok(())
74}
75
76/// Euclidean norm using explicit SIMD for contiguous `f32` slices.
77pub fn norm_l2_f32(x: &[f32]) -> f32 {
78    dot_f32(x, x)
79        .expect("matching input slices are valid")
80        .sqrt()
81}
82
83/// Euclidean norm using explicit SIMD for contiguous `f64` slices.
84pub fn norm_l2_f64(x: &[f64]) -> f64 {
85    dot_f64(x, x)
86        .expect("matching input slices are valid")
87        .sqrt()
88}
89
90struct DotF32<'a> {
91    x: &'a [f32],
92    y: &'a [f32],
93}
94
95impl WithSimd for DotF32<'_> {
96    type Output = f32;
97
98    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
99        let (x_head, x_tail) = S::as_simd_f32s(self.x);
100        let (y_head, y_tail) = S::as_simd_f32s(self.y);
101        let mut acc = simd.splat_f32s(0.0);
102        for (&x, &y) in x_head.iter().zip(y_head) {
103            acc = simd.mul_add_f32s(x, y, acc);
104        }
105        let mut sum = simd.reduce_sum_f32s(acc);
106        for (&x, &y) in x_tail.iter().zip(y_tail) {
107            sum += x * y;
108        }
109        sum
110    }
111}
112
113struct DotF64<'a> {
114    x: &'a [f64],
115    y: &'a [f64],
116}
117
118impl WithSimd for DotF64<'_> {
119    type Output = f64;
120
121    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
122        let (x_head, x_tail) = S::as_simd_f64s(self.x);
123        let (y_head, y_tail) = S::as_simd_f64s(self.y);
124        let mut acc = simd.splat_f64s(0.0);
125        for (&x, &y) in x_head.iter().zip(y_head) {
126            acc = simd.mul_add_f64s(x, y, acc);
127        }
128        let mut sum = simd.reduce_sum_f64s(acc);
129        for (&x, &y) in x_tail.iter().zip(y_tail) {
130            sum += x * y;
131        }
132        sum
133    }
134}
135
136struct AxpyF32<'a> {
137    alpha: f32,
138    x: &'a [f32],
139    y: &'a mut [f32],
140}
141
142impl WithSimd for AxpyF32<'_> {
143    type Output = ();
144
145    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
146        let (x_head, x_tail) = S::as_simd_f32s(self.x);
147        let (y_head, y_tail) = S::as_mut_simd_f32s(self.y);
148        let alpha = simd.splat_f32s(self.alpha);
149        for (y, &x) in y_head.iter_mut().zip(x_head) {
150            *y = simd.mul_add_f32s(alpha, x, *y);
151        }
152        for (y, &x) in y_tail.iter_mut().zip(x_tail) {
153            *y += self.alpha * x;
154        }
155    }
156}
157
158struct AxpyF64<'a> {
159    alpha: f64,
160    x: &'a [f64],
161    y: &'a mut [f64],
162}
163
164impl WithSimd for AxpyF64<'_> {
165    type Output = ();
166
167    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
168        let (x_head, x_tail) = S::as_simd_f64s(self.x);
169        let (y_head, y_tail) = S::as_mut_simd_f64s(self.y);
170        let alpha = simd.splat_f64s(self.alpha);
171        for (y, &x) in y_head.iter_mut().zip(x_head) {
172            *y = simd.mul_add_f64s(alpha, x, *y);
173        }
174        for (y, &x) in y_tail.iter_mut().zip(x_tail) {
175            *y += self.alpha * x;
176        }
177    }
178}
179
180/// Copy a rectangular view block into compact row-major storage.
181pub fn pack_block<T: Copy>(
182    a: ArrayView2<'_, T>,
183    row: usize,
184    col: usize,
185    rows: usize,
186    cols: usize,
187) -> Result<Array2<T>> {
188    if row > a.rows()
189        || col > a.cols()
190        || rows > a.rows().saturating_sub(row)
191        || cols > a.cols().saturating_sub(col)
192    {
193        return Err(Error::IndexOutOfBounds);
194    }
195    Ok(Array2::from_fn([rows, cols], |i, j| a[(row + i, col + j)]))
196}
197
198/// Copy a compact row-major block into a destination view at `(row, col)`.
199pub fn unpack_block<T: Copy>(
200    block: ArrayView2<'_, T>,
201    mut dst: ArrayViewMut2<'_, T>,
202    row: usize,
203    col: usize,
204) -> Result<()> {
205    if row > dst.rows()
206        || col > dst.cols()
207        || block.rows() > dst.rows().saturating_sub(row)
208        || block.cols() > dst.cols().saturating_sub(col)
209    {
210        return Err(Error::IndexOutOfBounds);
211    }
212    for i in 0..block.rows() {
213        for j in 0..block.cols() {
214            dst[(row + i, col + j)] = block[(i, j)];
215        }
216    }
217    Ok(())
218}
219
220/// Checked matrix multiplication with transpose flags:
221/// `C = alpha * op(A) * op(B) + beta * C`.
222pub fn gemm<T: Float>(
223    alpha: T,
224    a: ArrayView2<'_, T>,
225    trans_a: bool,
226    b: ArrayView2<'_, T>,
227    trans_b: bool,
228    beta: T,
229    c: ArrayViewMut2<'_, T>,
230) -> Result<()> {
231    let mut workspace = Workspace::new();
232    gemm_with_workspace(alpha, a, trans_a, b, trans_b, beta, c, &mut workspace)
233}
234
235/// Checked matrix multiplication using caller-provided reusable workspace:
236/// `C = alpha * op(A) * op(B) + beta * C`.
237#[allow(clippy::too_many_arguments)]
238pub fn gemm_with_workspace<T: Float>(
239    alpha: T,
240    a: ArrayView2<'_, T>,
241    trans_a: bool,
242    b: ArrayView2<'_, T>,
243    trans_b: bool,
244    beta: T,
245    c: ArrayViewMut2<'_, T>,
246    workspace: &mut Workspace<T>,
247) -> Result<()> {
248    gemm_blocked_workspace(
249        GemmBlocked {
250            alpha,
251            a,
252            trans_a,
253            b,
254            trans_b,
255            beta,
256            c,
257            block_size: 32,
258        },
259        workspace,
260    )
261}
262
263struct GemmBlocked<'a, 'b, 'c, T> {
264    alpha: T,
265    a: ArrayView2<'a, T>,
266    trans_a: bool,
267    b: ArrayView2<'b, T>,
268    trans_b: bool,
269    beta: T,
270    c: ArrayViewMut2<'c, T>,
271    block_size: usize,
272}
273
274fn gemm_blocked_workspace<T: Float>(
275    spec: GemmBlocked<'_, '_, '_, T>,
276    workspace: &mut Workspace<T>,
277) -> Result<()> {
278    let GemmBlocked {
279        alpha,
280        a,
281        trans_a,
282        b,
283        trans_b,
284        beta,
285        mut c,
286        block_size,
287    } = spec;
288    let (m, k_a) = if trans_a {
289        (a.cols(), a.rows())
290    } else {
291        (a.rows(), a.cols())
292    };
293    let (k_b, n) = if trans_b {
294        (b.cols(), b.rows())
295    } else {
296        (b.rows(), b.cols())
297    };
298    if k_a != k_b {
299        return Err(Error::shape(vec![m, k_a], vec![k_b, n]));
300    }
301    if c.shape() != [m, n] {
302        return Err(Error::shape(vec![m, n], c.shape()));
303    }
304    let block = block_size.max(1);
305
306    for i in 0..m {
307        for j in 0..n {
308            c[(i, j)] *= beta;
309        }
310    }
311
312    for i0 in (0..m).step_by(block) {
313        let ib = block.min(m - i0);
314        for p0 in (0..k_a).step_by(block) {
315            let pb = block.min(k_a - p0);
316            for j0 in (0..n).step_by(block) {
317                let jb = block.min(n - j0);
318                let (a_buffer, b_buffer) = workspace.two_buffers_mut(0, 1);
319                let a_block = a_buffer.zeros(ib * pb);
320                pack_op_block_into(a, trans_a, i0, p0, ib, pb, a_block);
321                let b_block = b_buffer.zeros(pb * jb);
322                pack_op_block_into(b, trans_b, p0, j0, pb, jb, b_block);
323                for i in (0..ib).step_by(4) {
324                    for j in (0..jb).step_by(4) {
325                        let rows = 4.min(ib - i);
326                        let cols = 4.min(jb - j);
327                        microkernel_4x4(
328                            alpha,
329                            PackedBlock {
330                                data: &a_block[i * pb..],
331                                rows,
332                                cols: pb,
333                            },
334                            PackedBlock {
335                                data: &b_block[j..],
336                                rows: pb,
337                                cols: jb,
338                            },
339                            &mut c,
340                            [i0 + i, j0 + j],
341                            cols,
342                        );
343                    }
344                }
345            }
346        }
347    }
348    Ok(())
349}
350
351fn pack_op_block_into<T: Float>(
352    a: ArrayView2<'_, T>,
353    trans: bool,
354    row: usize,
355    col: usize,
356    rows: usize,
357    cols: usize,
358    out: &mut [T],
359) {
360    for i in 0..rows {
361        for j in 0..cols {
362            out[i * cols + j] = if trans {
363                a[(col + j, row + i)]
364            } else {
365                a[(row + i, col + j)]
366            };
367        }
368    }
369}
370
/// Borrowed strip of a packed row-major tile: element `(r, k)` lives at `data[r * cols + k]`.
///
/// `data` may extend beyond the strip. Only indices of that form are read.
struct PackedBlock<'a, T> {
    /// Packed storage, starting at the strip's first element.
    data: &'a [T],
    /// Number of strip rows in use.
    rows: usize,
    /// Row stride of the packed storage.
    cols: usize,
}
376
377fn microkernel_4x4<T: Float>(
378    alpha: T,
379    a: PackedBlock<'_, T>,
380    b: PackedBlock<'_, T>,
381    c: &mut ArrayViewMut2<'_, T>,
382    c_origin: [usize; 2],
383    c_cols: usize,
384) {
385    let mut c00 = T::zero();
386    let mut c01 = T::zero();
387    let mut c02 = T::zero();
388    let mut c03 = T::zero();
389    let mut c10 = T::zero();
390    let mut c11 = T::zero();
391    let mut c12 = T::zero();
392    let mut c13 = T::zero();
393    let mut c20 = T::zero();
394    let mut c21 = T::zero();
395    let mut c22 = T::zero();
396    let mut c23 = T::zero();
397    let mut c30 = T::zero();
398    let mut c31 = T::zero();
399    let mut c32 = T::zero();
400    let mut c33 = T::zero();
401
402    for p in 0..a.cols {
403        let b0 = b.data[p * b.cols];
404        let b1 = if c_cols > 1 {
405            b.data[p * b.cols + 1]
406        } else {
407            T::zero()
408        };
409        let b2 = if c_cols > 2 {
410            b.data[p * b.cols + 2]
411        } else {
412            T::zero()
413        };
414        let b3 = if c_cols > 3 {
415            b.data[p * b.cols + 3]
416        } else {
417            T::zero()
418        };
419
420        let a0 = a.data[p];
421        c00 += a0 * b0;
422        c01 += a0 * b1;
423        c02 += a0 * b2;
424        c03 += a0 * b3;
425
426        if a.rows > 1 {
427            let a1 = a.data[a.cols + p];
428            c10 += a1 * b0;
429            c11 += a1 * b1;
430            c12 += a1 * b2;
431            c13 += a1 * b3;
432        }
433        if a.rows > 2 {
434            let a2 = a.data[2 * a.cols + p];
435            c20 += a2 * b0;
436            c21 += a2 * b1;
437            c22 += a2 * b2;
438            c23 += a2 * b3;
439        }
440        if a.rows > 3 {
441            let a3 = a.data[3 * a.cols + p];
442            c30 += a3 * b0;
443            c31 += a3 * b1;
444            c32 += a3 * b2;
445            c33 += a3 * b3;
446        }
447    }
448
449    accumulate_tile(alpha, c, c_origin, 0, &[c00, c01, c02, c03], c_cols);
450    if a.rows > 1 {
451        accumulate_tile(alpha, c, c_origin, 1, &[c10, c11, c12, c13], c_cols);
452    }
453    if a.rows > 2 {
454        accumulate_tile(alpha, c, c_origin, 2, &[c20, c21, c22, c23], c_cols);
455    }
456    if a.rows > 3 {
457        accumulate_tile(alpha, c, c_origin, 3, &[c30, c31, c32, c33], c_cols);
458    }
459}
460
461fn accumulate_tile<T: Float>(
462    alpha: T,
463    c: &mut ArrayViewMut2<'_, T>,
464    origin: [usize; 2],
465    row: usize,
466    values: &[T; 4],
467    cols: usize,
468) {
469    for col in 0..cols {
470        c[(origin[0] + row, origin[1] + col)] += alpha * values[col];
471    }
472}
473
474/// Return `A * B`.
475pub fn matmul<T: Float>(a: ArrayView2<'_, T>, b: ArrayView2<'_, T>) -> Result<Array2<T>> {
476    if a.cols() != b.rows() {
477        return Err(Error::shape(a.shape(), b.shape()));
478    }
479    let mut c = Array2::zeros([a.rows(), b.cols()]);
480    gemm(T::one(), a, false, b, false, T::zero(), c.view_mut())?;
481    Ok(c)
482}
483
484/// Matrix-like object usable by algorithms without requiring materialization.
485pub trait LinearOperator<T: Float> {
486    /// Number of rows.
487    fn rows(&self) -> usize;
488    /// Number of columns.
489    fn cols(&self) -> usize;
490
491    /// Compute `y = A x`.
492    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()>;
493
494    /// Compute `y = A^T x`.
495    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()>;
496
497    /// Compute `Y = A X`.
498    fn matmat(&self, x: ArrayView2<'_, T>, mut y: ArrayViewMut2<'_, T>) -> Result<()> {
499        if x.rows() != self.cols() || y.shape() != [self.rows(), x.cols()] {
500            return Err(Error::shape(vec![self.cols(), x.cols()], x.shape()));
501        }
502        for col in 0..x.cols() {
503            let mut input = vec![T::zero(); x.rows()];
504            let mut output = vec![T::zero(); y.rows()];
505            for row in 0..x.rows() {
506                input[row] = x[(row, col)];
507            }
508            self.matvec(&input, &mut output)?;
509            for row in 0..y.rows() {
510                y[(row, col)] = output[row];
511            }
512        }
513        Ok(())
514    }
515
516    /// Compute `Y = A^T X`.
517    fn t_matmat(&self, x: ArrayView2<'_, T>, mut y: ArrayViewMut2<'_, T>) -> Result<()> {
518        if x.rows() != self.rows() || y.shape() != [self.cols(), x.cols()] {
519            return Err(Error::shape(vec![self.rows(), x.cols()], x.shape()));
520        }
521        for col in 0..x.cols() {
522            let mut input = vec![T::zero(); x.rows()];
523            let mut output = vec![T::zero(); y.rows()];
524            for row in 0..x.rows() {
525                input[row] = x[(row, col)];
526            }
527            self.t_matvec(&input, &mut output)?;
528            for row in 0..y.rows() {
529                y[(row, col)] = output[row];
530            }
531        }
532        Ok(())
533    }
534}
535
536impl<T: Float> LinearOperator<T> for Array2<T> {
537    fn rows(&self) -> usize {
538        self.rows()
539    }
540
541    fn cols(&self) -> usize {
542        self.cols()
543    }
544
545    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
546        if x.len() != self.cols() || y.len() != self.rows() {
547            return Err(Error::shape(
548                vec![self.cols(), self.rows()],
549                vec![x.len(), y.len()],
550            ));
551        }
552        for i in 0..self.rows() {
553            let mut sum = T::zero();
554            for j in 0..self.cols() {
555                sum += self[(i, j)] * x[j];
556            }
557            y[i] = sum;
558        }
559        Ok(())
560    }
561
562    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
563        if x.len() != self.rows() || y.len() != self.cols() {
564            return Err(Error::shape(
565                vec![self.rows(), self.cols()],
566                vec![x.len(), y.len()],
567            ));
568        }
569        y.fill(T::zero());
570        for i in 0..self.rows() {
571            for j in 0..self.cols() {
572                y[j] += self[(i, j)] * x[i];
573            }
574        }
575        Ok(())
576    }
577
578    fn matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
579        gemm(T::one(), self.view(), false, x, false, T::zero(), y)
580    }
581
582    fn t_matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
583        gemm(T::one(), self.view(), true, x, false, T::zero(), y)
584    }
585}
586
587impl<T: Float> LinearOperator<T> for ArrayView2<'_, T> {
588    fn rows(&self) -> usize {
589        self.rows()
590    }
591
592    fn cols(&self) -> usize {
593        self.cols()
594    }
595
596    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
597        if x.len() != self.cols() || y.len() != self.rows() {
598            return Err(Error::shape(
599                vec![self.cols(), self.rows()],
600                vec![x.len(), y.len()],
601            ));
602        }
603        for i in 0..self.rows() {
604            let mut sum = T::zero();
605            for j in 0..self.cols() {
606                sum += self[(i, j)] * x[j];
607            }
608            y[i] = sum;
609        }
610        Ok(())
611    }
612
613    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
614        if x.len() != self.rows() || y.len() != self.cols() {
615            return Err(Error::shape(
616                vec![self.rows(), self.cols()],
617                vec![x.len(), y.len()],
618            ));
619        }
620        y.fill(T::zero());
621        for i in 0..self.rows() {
622            for j in 0..self.cols() {
623                y[j] += self[(i, j)] * x[i];
624            }
625        }
626        Ok(())
627    }
628
629    fn matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
630        gemm(T::one(), *self, false, x, false, T::zero(), y)
631    }
632
633    fn t_matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
634        gemm(T::one(), *self, true, x, false, T::zero(), y)
635    }
636}
637
/// Lazy transpose wrapper: presents `inner` as `inner^T` without materializing the transpose.
#[derive(Clone, Copy, Debug)]
pub struct Transpose<A> {
    inner: A,
}

impl<A> Transpose<A> {
    /// Wrap `inner` so that it behaves as its own transpose.
    pub fn new(inner: A) -> Self {
        Transpose { inner }
    }
}
650
651impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for Transpose<A> {
652    fn rows(&self) -> usize {
653        self.inner.cols()
654    }
655
656    fn cols(&self) -> usize {
657        self.inner.rows()
658    }
659
660    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
661        self.inner.t_matvec(x, y)
662    }
663
664    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
665        self.inner.matvec(x, y)
666    }
667}
668
669/// Lazy column-centered operator.
670#[derive(Clone, Debug)]
671pub struct CenteredOperator<A, T> {
672    inner: A,
673    means: Vec<T>,
674}
675
676impl<A, T: Float> CenteredOperator<A, T> {
677    /// Create a column-centered wrapper with one mean per column.
678    pub fn new(inner: A, means: Vec<T>) -> Self {
679        Self { inner, means }
680    }
681
682    /// Borrow the column means used by this operator.
683    pub fn means(&self) -> &[T] {
684        &self.means
685    }
686}
687
688impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for CenteredOperator<A, T> {
689    fn rows(&self) -> usize {
690        self.inner.rows()
691    }
692
693    fn cols(&self) -> usize {
694        self.inner.cols()
695    }
696
697    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
698        if self.means.len() != self.cols() {
699            return Err(Error::shape(vec![self.cols()], vec![self.means.len()]));
700        }
701        self.inner.matvec(x, y)?;
702        let correction: T = self.means.iter().zip(x).map(|(&mean, &xj)| mean * xj).sum();
703        for yi in y {
704            *yi -= correction;
705        }
706        Ok(())
707    }
708
709    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
710        if self.means.len() != self.cols() {
711            return Err(Error::shape(vec![self.cols()], vec![self.means.len()]));
712        }
713        self.inner.t_matvec(x, y)?;
714        let total: T = x.iter().copied().sum();
715        for (yj, &mean) in y.iter_mut().zip(&self.means) {
716            *yj -= mean * total;
717        }
718        Ok(())
719    }
720}
721
722/// Lazy column-scaled operator representing `A * diag(scales)`.
723#[derive(Clone, Debug)]
724pub struct ColumnScaledOperator<A, T> {
725    inner: A,
726    scales: Vec<T>,
727}
728
729impl<A, T: Float> ColumnScaledOperator<A, T> {
730    /// Create a column-scaled wrapper with one scale per column.
731    pub fn new(inner: A, scales: Vec<T>) -> Self {
732        Self { inner, scales }
733    }
734
735    /// Borrow the column scales used by this operator.
736    pub fn scales(&self) -> &[T] {
737        &self.scales
738    }
739}
740
741impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for ColumnScaledOperator<A, T> {
742    fn rows(&self) -> usize {
743        self.inner.rows()
744    }
745
746    fn cols(&self) -> usize {
747        self.inner.cols()
748    }
749
750    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
751        if self.scales.len() != self.cols() {
752            return Err(Error::shape(vec![self.cols()], vec![self.scales.len()]));
753        }
754        if x.len() != self.cols() {
755            return Err(Error::shape(vec![self.cols()], vec![x.len()]));
756        }
757        let scaled = x
758            .iter()
759            .zip(&self.scales)
760            .map(|(&value, &scale)| value * scale)
761            .collect::<Vec<_>>();
762        self.inner.matvec(&scaled, y)
763    }
764
765    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
766        if self.scales.len() != self.cols() {
767            return Err(Error::shape(vec![self.cols()], vec![self.scales.len()]));
768        }
769        self.inner.t_matvec(x, y)?;
770        for (yj, &scale) in y.iter_mut().zip(&self.scales) {
771            *yj *= scale;
772        }
773        Ok(())
774    }
775}
776
777/// Lazy row-scaled operator representing `diag(scales) * A`.
778#[derive(Clone, Debug)]
779pub struct RowScaledOperator<A, T> {
780    inner: A,
781    scales: Vec<T>,
782}
783
784impl<A, T: Float> RowScaledOperator<A, T> {
785    /// Create a row-scaled wrapper with one scale per row.
786    pub fn new(inner: A, scales: Vec<T>) -> Self {
787        Self { inner, scales }
788    }
789
790    /// Borrow the row scales used by this operator.
791    pub fn scales(&self) -> &[T] {
792        &self.scales
793    }
794}
795
796impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for RowScaledOperator<A, T> {
797    fn rows(&self) -> usize {
798        self.inner.rows()
799    }
800
801    fn cols(&self) -> usize {
802        self.inner.cols()
803    }
804
805    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
806        if self.scales.len() != self.rows() {
807            return Err(Error::shape(vec![self.rows()], vec![self.scales.len()]));
808        }
809        self.inner.matvec(x, y)?;
810        for (yi, &scale) in y.iter_mut().zip(&self.scales) {
811            *yi *= scale;
812        }
813        Ok(())
814    }
815
816    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
817        if self.scales.len() != self.rows() {
818            return Err(Error::shape(vec![self.rows()], vec![self.scales.len()]));
819        }
820        if x.len() != self.rows() {
821            return Err(Error::shape(vec![self.rows()], vec![x.len()]));
822        }
823        let scaled = x
824            .iter()
825            .zip(&self.scales)
826            .map(|(&value, &scale)| value * scale)
827            .collect::<Vec<_>>();
828        self.inner.t_matvec(&scaled, y)
829    }
830}
831
832/// Lazy column-standardized operator representing `(A - means) * diag(1 / scales)`.
833#[derive(Clone, Debug)]
834pub struct StandardizedOperator<A, T> {
835    inner: A,
836    means: Vec<T>,
837    scales: Vec<T>,
838}
839
840impl<A, T: Float> StandardizedOperator<A, T> {
841    /// Create a column-standardized wrapper with one mean and scale per column.
842    pub fn new(inner: A, means: Vec<T>, scales: Vec<T>) -> Self {
843        Self {
844            inner,
845            means,
846            scales,
847        }
848    }
849
850    /// Borrow the column means used by this operator.
851    pub fn means(&self) -> &[T] {
852        &self.means
853    }
854
855    /// Borrow the column scales used by this operator.
856    pub fn scales(&self) -> &[T] {
857        &self.scales
858    }
859}
860
861impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for StandardizedOperator<A, T> {
862    fn rows(&self) -> usize {
863        self.inner.rows()
864    }
865
866    fn cols(&self) -> usize {
867        self.inner.cols()
868    }
869
870    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
871        validate_standardized_parts(self.cols(), &self.means, &self.scales)?;
872        if x.len() != self.cols() {
873            return Err(Error::shape(vec![self.cols()], vec![x.len()]));
874        }
875        let scaled = x
876            .iter()
877            .zip(&self.scales)
878            .map(|(&value, &scale)| value / scale)
879            .collect::<Vec<_>>();
880        self.inner.matvec(&scaled, y)?;
881        let correction: T = self
882            .means
883            .iter()
884            .zip(&scaled)
885            .map(|(&mean, &xj)| mean * xj)
886            .sum();
887        for yi in y {
888            *yi -= correction;
889        }
890        Ok(())
891    }
892
893    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
894        validate_standardized_parts(self.cols(), &self.means, &self.scales)?;
895        self.inner.t_matvec(x, y)?;
896        let total: T = x.iter().copied().sum();
897        for ((yj, &mean), &scale) in y.iter_mut().zip(&self.means).zip(&self.scales) {
898            *yj = (*yj - mean * total) / scale;
899        }
900        Ok(())
901    }
902}
903
904fn validate_standardized_parts<T: Float>(cols: usize, means: &[T], scales: &[T]) -> Result<()> {
905    if means.len() != cols {
906        return Err(Error::shape(vec![cols], vec![means.len()]));
907    }
908    if scales.len() != cols {
909        return Err(Error::shape(vec![cols], vec![scales.len()]));
910    }
911    if scales.iter().any(|&scale| scale == T::zero()) {
912        return Err(Error::NumericalFailure("standardization scale is zero"));
913    }
914    Ok(())
915}
916
/// Tuning knobs for randomized SVD.
#[derive(Clone, Debug)]
pub struct RandomizedSvdOptions {
    /// Number of singular triplets to compute; must be in `1..=min(rows, cols)`.
    pub rank: usize,
    /// Additional random probe vectors beyond `rank` (commonly 5-20).
    pub oversampling: usize,
    /// Number of power (subspace) iterations used to sharpen the range estimate.
    pub power_iterations: usize,
    /// Seed for the random test matrix; `None` selects a fixed default seed.
    pub seed: Option<u64>,
    /// Approximation tolerance; currently reserved and not used by the algorithm.
    pub tolerance: Option<f64>,
    /// Compute the left singular vectors `U`.
    pub compute_u: bool,
    /// Compute the transposed right singular vectors `V^T`.
    pub compute_vt: bool,
}

/// Rank 2, oversampling 8, one power iteration, default seed, both singular-vector sets.
impl Default for RandomizedSvdOptions {
    fn default() -> Self {
        RandomizedSvdOptions {
            rank: 2,
            oversampling: 8,
            power_iterations: 1,
            seed: None,
            tolerance: None,
            compute_u: true,
            compute_vt: true,
        }
    }
}
949
950/// Singular value decomposition result.
951#[derive(Clone, Debug, PartialEq)]
952pub struct SvdResult<T> {
953    /// Left singular vectors.
954    pub u: Array2<T>,
955    /// Singular values.
956    pub s: Vec<T>,
957    /// Right singular vectors transposed.
958    pub vt: Array2<T>,
959}
960
961/// Symmetric eigendecomposition result.
962#[derive(Clone, Debug, PartialEq)]
963pub struct EighResult<T> {
964    /// Eigenvalues sorted in descending order.
965    pub eigenvalues: Vec<T>,
966    /// Eigenvectors stored as columns.
967    pub eigenvectors: Array2<T>,
968}
969
970/// Thin QR decomposition result.
971#[derive(Clone, Debug, PartialEq)]
972pub struct QrResult<T> {
973    /// Orthonormal basis columns.
974    pub q: Array2<T>,
975    /// Upper-triangular coefficients with shape `q.cols() x input.cols()`.
976    pub r: Array2<T>,
977}
978
979/// Modified Gram-Schmidt QR with one reorthogonalization pass.
980pub fn qr<T: Float>(a: ArrayView2<'_, T>) -> Result<QrResult<T>> {
981    let mut columns: Vec<Vec<T>> = Vec::new();
982    let mut r_rows: Vec<Vec<T>> = Vec::new();
983    for j in 0..a.cols() {
984        let mut column = (0..a.rows()).map(|i| a[(i, j)]).collect::<Vec<_>>();
985        for _ in 0..2 {
986            for (idx, prev) in columns.iter().enumerate() {
987                let mut projection = T::zero();
988                for i in 0..a.rows() {
989                    projection += prev[i] * column[i];
990                }
991                r_rows[idx][j] += projection;
992                for i in 0..a.rows() {
993                    column[i] -= projection * prev[i];
994                }
995            }
996        }
997        let mut norm = T::zero();
998        for &value in &column {
999            norm += value * value;
1000        }
1001        norm = norm.sqrt();
1002        if norm <= T::from_f64(1e-12) {
1003            continue;
1004        }
1005        if !norm.is_finite() {
1006            return Err(Error::NumericalFailure("non-finite QR column norm"));
1007        }
1008        for value in &mut column {
1009            *value /= norm;
1010        }
1011        columns.push(column);
1012        let mut r_row = vec![T::zero(); a.cols()];
1013        r_row[j] = norm;
1014        r_rows.push(r_row);
1015    }
1016    if columns.is_empty() && a.cols() > 0 {
1017        return Err(Error::NumericalFailure(
1018            "matrix has no independent QR columns",
1019        ));
1020    }
1021    let q = Array2::from_fn([a.rows(), columns.len()], |i, j| columns[j][i]);
1022    let r = Array2::from_fn([r_rows.len(), a.cols()], |i, j| r_rows[i][j]);
1023    Ok(QrResult { q, r })
1024}
1025
1026/// Modified Gram-Schmidt thin QR. Returns only `Q`.
1027pub fn thin_qr<T: Float>(a: ArrayView2<'_, T>) -> Result<Array2<T>> {
1028    Ok(qr(a)?.q)
1029}
1030
1031/// Reorthogonalize approximate basis columns.
1032pub fn reorthogonalize<T: Float>(q: ArrayView2<'_, T>) -> Result<Array2<T>> {
1033    thin_qr(q)
1034}
1035
1036/// Randomized range finder returning an orthonormal basis `Q`.
1037pub fn randomized_range_finder<T: Float, A: LinearOperator<T>>(
1038    a: &A,
1039    rank: usize,
1040    oversampling: usize,
1041    power_iterations: usize,
1042    seed: Option<u64>,
1043) -> Result<Array2<T>> {
1044    let l = (rank + oversampling).min(a.cols()).min(a.rows());
1045    if rank == 0 || rank > a.rows().min(a.cols()) {
1046        return Err(Error::RankTooLarge {
1047            requested: rank,
1048            max: a.rows().min(a.cols()),
1049        });
1050    }
1051    let mut rng = SmallRng::new(seed.unwrap_or(0x5eed_1234_abcd_9876));
1052    let omega = Array2::from_fn([a.cols(), l], |_, _| rng.normal::<T>());
1053    let mut y = Array2::zeros([a.rows(), l]);
1054    a.matmat(omega.view(), y.view_mut())?;
1055
1056    for _ in 0..power_iterations {
1057        let q = thin_qr(y.view())?;
1058        let mut z = Array2::zeros([a.cols(), q.cols()]);
1059        a.t_matmat(q.view(), z.view_mut())?;
1060        y = Array2::zeros([a.rows(), z.cols()]);
1061        a.matmat(z.view(), y.view_mut())?;
1062    }
1063
1064    thin_qr(y.view())
1065}
1066
1067/// Pure-Rust randomized SVD entry point.
1068pub fn randomized_svd<T: Float, A: LinearOperator<T>>(
1069    a: &A,
1070    options: RandomizedSvdOptions,
1071) -> Result<SvdResult<T>> {
1072    if options.rank == 0 || options.rank > a.rows().min(a.cols()) {
1073        return Err(Error::RankTooLarge {
1074            requested: options.rank,
1075            max: a.rows().min(a.cols()),
1076        });
1077    }
1078    let q = randomized_range_finder(
1079        a,
1080        options.rank,
1081        options.oversampling,
1082        options.power_iterations,
1083        options.seed,
1084    )?;
1085    let mut at_q = Array2::zeros([a.cols(), q.cols()]);
1086    a.t_matmat(q.view(), at_q.view_mut())?;
1087    let b = Array2::clone_contiguous(at_q.transpose_view());
1088    let small = svd_small(b.view())?;
1089    let rank = options.rank.min(small.s.len());
1090    let u = if options.compute_u {
1091        let projected = matmul(q.view(), small.u.view())?;
1092        Array2::from_fn([a.rows(), rank], |i, j| projected[(i, j)])
1093    } else {
1094        Array2::zeros([0, 0])
1095    };
1096    let s = small.s.into_iter().take(rank).collect();
1097    let vt = if options.compute_vt {
1098        Array2::from_fn([rank, a.cols()], |i, j| small.vt[(i, j)])
1099    } else {
1100        Array2::zeros([0, 0])
1101    };
1102    Ok(SvdResult { u, s, vt })
1103}
1104
1105/// Run randomized SVD on a dense view and return its reconstruction error.
1106///
1107/// When `options.tolerance` is set, the function returns `Error::NotConverged`
1108/// if the Frobenius reconstruction error is larger than the tolerance. The
1109/// residual requires both singular vector factors internally, but the returned
1110/// result still honors `compute_u` and `compute_vt`.
1111pub fn randomized_svd_with_error<T: Float>(
1112    a: ArrayView2<'_, T>,
1113    options: RandomizedSvdOptions,
1114) -> Result<(SvdResult<T>, T)> {
1115    let compute_u = options.compute_u;
1116    let compute_vt = options.compute_vt;
1117    let work_options = RandomizedSvdOptions {
1118        compute_u: true,
1119        compute_vt: true,
1120        ..options.clone()
1121    };
1122    let mut result = randomized_svd(&a, work_options)?;
1123    let error = approx_reconstruction_error(a, result.u.view(), &result.s, result.vt.view())?;
1124    if let Some(tolerance) = options.tolerance
1125        && error.to_f64() > tolerance
1126    {
1127        return Err(Error::NotConverged);
1128    }
1129    if !compute_u {
1130        result.u = Array2::zeros([0, 0]);
1131    }
1132    if !compute_vt {
1133        result.vt = Array2::zeros([0, 0]);
1134    }
1135    Ok((result, error))
1136}
1137
1138/// Run randomized SVD independently over each 2D slice of a 3D tensor.
1139pub fn batch_randomized_svd<T: Float>(
1140    a: ArrayView3<'_, T>,
1141    axis: Axis3,
1142    options: RandomizedSvdOptions,
1143) -> Result<Vec<SvdResult<T>>> {
1144    let axis_index = axis.index();
1145    let mut results = Vec::with_capacity(a.shape()[axis_index]);
1146    for index in 0..a.shape()[axis_index] {
1147        let matrix = a.matrix_at(axis_index, index)?;
1148        results.push(randomized_svd(&matrix, options.clone())?);
1149    }
1150    Ok(results)
1151}
1152
1153/// Run randomized SVD independently over each 2D slice using Rayon.
1154///
1155/// Results are returned in axis order, matching [`batch_randomized_svd`].
1156pub fn batch_randomized_svd_parallel<T: Float>(
1157    a: ArrayView3<'_, T>,
1158    axis: Axis3,
1159    options: RandomizedSvdOptions,
1160) -> Result<Vec<SvdResult<T>>> {
1161    let axis_index = axis.index();
1162    (0..a.shape()[axis_index])
1163        .into_par_iter()
1164        .map(|index| {
1165            let matrix = a.matrix_at(axis_index, index)?;
1166            randomized_svd(&matrix, options.clone())
1167        })
1168        .collect()
1169}
1170
1171/// Frobenius norm of `A - U diag(S) Vt`.
1172pub fn approx_reconstruction_error<T: Float>(
1173    a: ArrayView2<'_, T>,
1174    u: ArrayView2<'_, T>,
1175    s: &[T],
1176    vt: ArrayView2<'_, T>,
1177) -> Result<T> {
1178    if u.rows() != a.rows() || u.cols() != s.len() {
1179        return Err(Error::shape(vec![a.rows(), s.len()], u.shape()));
1180    }
1181    if vt.rows() != s.len() || vt.cols() != a.cols() {
1182        return Err(Error::shape(vec![s.len(), a.cols()], vt.shape()));
1183    }
1184
1185    let mut residual = T::zero();
1186    for i in 0..a.rows() {
1187        for j in 0..a.cols() {
1188            let mut approx = T::zero();
1189            for r in 0..s.len() {
1190                approx += u[(i, r)] * s[r] * vt[(r, j)];
1191            }
1192            let diff = a[(i, j)] - approx;
1193            residual += diff * diff;
1194        }
1195    }
1196    Ok(residual.sqrt())
1197}
1198
1199/// Fraction of squared singular value energy explained by each value.
1200pub fn explained_variance_ratio<T: Float>(s: &[T]) -> Vec<T> {
1201    let total: T = s.iter().copied().map(|value| value * value).sum();
1202    if total == T::zero() {
1203        return vec![T::zero(); s.len()];
1204    }
1205    s.iter()
1206        .copied()
1207        .map(|value| value * value / total)
1208        .collect()
1209}
1210
1211/// Symmetric eigendecomposition for small dense matrices.
1212pub fn eigh_small<T: Float>(a: ArrayView2<'_, T>) -> Result<EighResult<T>> {
1213    if a.rows() != a.cols() {
1214        return Err(Error::shape([a.rows(), a.rows()], a.shape()));
1215    }
1216    for i in 0..a.rows() {
1217        for j in (i + 1)..a.cols() {
1218            if (a[(i, j)] - a[(j, i)]).abs() > T::from_f64(1e-9) {
1219                return Err(Error::NumericalFailure("matrix is not symmetric"));
1220            }
1221        }
1222    }
1223
1224    let mut eig = jacobi_symmetric(Array2::from_fn(a.shape(), |i, j| a[(i, j)].to_f64()))?;
1225    eig.sort_by(|left, right| {
1226        right
1227            .0
1228            .partial_cmp(&left.0)
1229            .unwrap_or(core::cmp::Ordering::Equal)
1230    });
1231
1232    let eigenvalues = eig
1233        .iter()
1234        .map(|(value, _)| T::from_f64(*value))
1235        .collect::<Vec<_>>();
1236    let eigenvectors = Array2::from_fn([a.rows(), a.cols()], |i, j| T::from_f64(eig[j].1[i]));
1237    Ok(EighResult {
1238        eigenvalues,
1239        eigenvectors,
1240    })
1241}
1242
1243/// Singular value decomposition for small dense matrices.
1244pub fn svd_small<T: Float>(a: ArrayView2<'_, T>) -> Result<SvdResult<T>> {
1245    let gram = gram_left(a);
1246    let mut eig = jacobi_symmetric(gram)?;
1247    eig.sort_by(|left, right| {
1248        right
1249            .0
1250            .partial_cmp(&left.0)
1251            .unwrap_or(core::cmp::Ordering::Equal)
1252    });
1253
1254    let rank = eig.len().min(a.rows()).min(a.cols());
1255    let mut u = Array2::zeros([a.rows(), rank]);
1256    let mut s = vec![T::zero(); rank];
1257    for j in 0..rank {
1258        let value = eig[j].0.max(0.0).sqrt();
1259        s[j] = T::from_f64(value);
1260        for i in 0..a.rows() {
1261            u[(i, j)] = T::from_f64(eig[j].1[i]);
1262        }
1263    }
1264
1265    let mut vt = Array2::zeros([rank, a.cols()]);
1266    for r in 0..rank {
1267        if s[r] <= T::from_f64(1e-12) {
1268            continue;
1269        }
1270        for col in 0..a.cols() {
1271            let mut value = T::zero();
1272            for row in 0..a.rows() {
1273                value += u[(row, r)] * a[(row, col)];
1274            }
1275            vt[(r, col)] = value / s[r];
1276        }
1277    }
1278    Ok(SvdResult { u, s, vt })
1279}
1280
1281fn gram_left<T: Float>(a: ArrayView2<'_, T>) -> Array2<f64> {
1282    Array2::from_fn([a.rows(), a.rows()], |i, j| {
1283        let mut sum = 0.0;
1284        for col in 0..a.cols() {
1285            sum += a[(i, col)].to_f64() * a[(j, col)].to_f64();
1286        }
1287        sum
1288    })
1289}
1290
1291fn jacobi_symmetric(mut a: Array2<f64>) -> Result<Vec<(f64, Vec<f64>)>> {
1292    if a.rows() != a.cols() {
1293        return Err(Error::shape([a.rows(), a.rows()], a.shape()));
1294    }
1295    let n = a.rows();
1296    let mut v = Array2::from_fn([n, n], |i, j| if i == j { 1.0 } else { 0.0 });
1297    let max_iter = 64usize.saturating_mul(n.max(1)).saturating_mul(n.max(1));
1298
1299    for _ in 0..max_iter {
1300        let mut p = 0;
1301        let mut q = 0;
1302        let mut max = 0.0;
1303        for i in 0..n {
1304            for j in (i + 1)..n {
1305                let value = a[(i, j)].abs();
1306                if value > max {
1307                    max = value;
1308                    p = i;
1309                    q = j;
1310                }
1311            }
1312        }
1313        if max < 1e-12 {
1314            let mut result = Vec::with_capacity(n);
1315            for col in 0..n {
1316                let mut vector = Vec::with_capacity(n);
1317                for row in 0..n {
1318                    vector.push(v[(row, col)]);
1319                }
1320                result.push((a[(col, col)], vector));
1321            }
1322            return Ok(result);
1323        }
1324
1325        let app = a[(p, p)];
1326        let aqq = a[(q, q)];
1327        let apq = a[(p, q)];
1328        let tau = (aqq - app) / (2.0 * apq);
1329        let t = tau.signum() / (tau.abs() + (1.0 + tau * tau).sqrt());
1330        let c = 1.0 / (1.0 + t * t).sqrt();
1331        let s = t * c;
1332
1333        for k in 0..n {
1334            if k != p && k != q {
1335                let akp = a[(k, p)];
1336                let akq = a[(k, q)];
1337                let new_kp = c * akp - s * akq;
1338                let new_kq = s * akp + c * akq;
1339                a[(k, p)] = new_kp;
1340                a[(p, k)] = new_kp;
1341                a[(k, q)] = new_kq;
1342                a[(q, k)] = new_kq;
1343            }
1344        }
1345
1346        a[(p, p)] = c * c * app - 2.0 * s * c * apq + s * s * aqq;
1347        a[(q, q)] = s * s * app + 2.0 * s * c * apq + c * c * aqq;
1348        a[(p, q)] = 0.0;
1349        a[(q, p)] = 0.0;
1350
1351        for k in 0..n {
1352            let vkp = v[(k, p)];
1353            let vkq = v[(k, q)];
1354            v[(k, p)] = c * vkp - s * vkq;
1355            v[(k, q)] = s * vkp + c * vkq;
1356        }
1357    }
1358
1359    Err(Error::NotConverged)
1360}