1use crate::array2::Array2;
4use crate::array3::Axis3;
5use crate::error::{Error, Result};
6use crate::numeric::Float;
7use crate::rand::SmallRng;
8use crate::view2::{ArrayView2, ArrayViewMut2};
9use crate::view3::ArrayView3;
10use crate::workspace::Workspace;
11use pulp::{Arch, Simd, WithSimd};
12use rayon::prelude::*;
13
14pub fn dot<T: Float>(x: &[T], y: &[T]) -> Result<T> {
16 if x.len() != y.len() {
17 return Err(Error::shape(vec![x.len()], vec![y.len()]));
18 }
19 Ok(x.iter().zip(y).map(|(&a, &b)| a * b).sum())
20}
21
22pub fn axpy<T: Float>(alpha: T, x: &[T], y: &mut [T]) -> Result<()> {
24 if x.len() != y.len() {
25 return Err(Error::shape(vec![x.len()], vec![y.len()]));
26 }
27 for (yi, &xi) in y.iter_mut().zip(x) {
28 *yi += alpha * xi;
29 }
30 Ok(())
31}
32
33pub fn norm_l2<T: Float>(x: &[T]) -> T {
35 x.iter()
36 .copied()
37 .map(|value| value * value)
38 .sum::<T>()
39 .sqrt()
40}
41
42pub fn dot_f32(x: &[f32], y: &[f32]) -> Result<f32> {
44 if x.len() != y.len() {
45 return Err(Error::shape(vec![x.len()], vec![y.len()]));
46 }
47 Ok(Arch::new().dispatch(DotF32 { x, y }))
48}
49
50pub fn dot_f64(x: &[f64], y: &[f64]) -> Result<f64> {
52 if x.len() != y.len() {
53 return Err(Error::shape(vec![x.len()], vec![y.len()]));
54 }
55 Ok(Arch::new().dispatch(DotF64 { x, y }))
56}
57
58pub fn axpy_f32(alpha: f32, x: &[f32], y: &mut [f32]) -> Result<()> {
60 if x.len() != y.len() {
61 return Err(Error::shape(vec![x.len()], vec![y.len()]));
62 }
63 Arch::new().dispatch(AxpyF32 { alpha, x, y });
64 Ok(())
65}
66
67pub fn axpy_f64(alpha: f64, x: &[f64], y: &mut [f64]) -> Result<()> {
69 if x.len() != y.len() {
70 return Err(Error::shape(vec![x.len()], vec![y.len()]));
71 }
72 Arch::new().dispatch(AxpyF64 { alpha, x, y });
73 Ok(())
74}
75
76pub fn norm_l2_f32(x: &[f32]) -> f32 {
78 dot_f32(x, x)
79 .expect("matching input slices are valid")
80 .sqrt()
81}
82
83pub fn norm_l2_f64(x: &[f64]) -> f64 {
85 dot_f64(x, x)
86 .expect("matching input slices are valid")
87 .sqrt()
88}
89
90struct DotF32<'a> {
91 x: &'a [f32],
92 y: &'a [f32],
93}
94
95impl WithSimd for DotF32<'_> {
96 type Output = f32;
97
98 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
99 let (x_head, x_tail) = S::as_simd_f32s(self.x);
100 let (y_head, y_tail) = S::as_simd_f32s(self.y);
101 let mut acc = simd.splat_f32s(0.0);
102 for (&x, &y) in x_head.iter().zip(y_head) {
103 acc = simd.mul_add_f32s(x, y, acc);
104 }
105 let mut sum = simd.reduce_sum_f32s(acc);
106 for (&x, &y) in x_tail.iter().zip(y_tail) {
107 sum += x * y;
108 }
109 sum
110 }
111}
112
113struct DotF64<'a> {
114 x: &'a [f64],
115 y: &'a [f64],
116}
117
118impl WithSimd for DotF64<'_> {
119 type Output = f64;
120
121 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
122 let (x_head, x_tail) = S::as_simd_f64s(self.x);
123 let (y_head, y_tail) = S::as_simd_f64s(self.y);
124 let mut acc = simd.splat_f64s(0.0);
125 for (&x, &y) in x_head.iter().zip(y_head) {
126 acc = simd.mul_add_f64s(x, y, acc);
127 }
128 let mut sum = simd.reduce_sum_f64s(acc);
129 for (&x, &y) in x_tail.iter().zip(y_tail) {
130 sum += x * y;
131 }
132 sum
133 }
134}
135
136struct AxpyF32<'a> {
137 alpha: f32,
138 x: &'a [f32],
139 y: &'a mut [f32],
140}
141
142impl WithSimd for AxpyF32<'_> {
143 type Output = ();
144
145 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
146 let (x_head, x_tail) = S::as_simd_f32s(self.x);
147 let (y_head, y_tail) = S::as_mut_simd_f32s(self.y);
148 let alpha = simd.splat_f32s(self.alpha);
149 for (y, &x) in y_head.iter_mut().zip(x_head) {
150 *y = simd.mul_add_f32s(alpha, x, *y);
151 }
152 for (y, &x) in y_tail.iter_mut().zip(x_tail) {
153 *y += self.alpha * x;
154 }
155 }
156}
157
158struct AxpyF64<'a> {
159 alpha: f64,
160 x: &'a [f64],
161 y: &'a mut [f64],
162}
163
164impl WithSimd for AxpyF64<'_> {
165 type Output = ();
166
167 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
168 let (x_head, x_tail) = S::as_simd_f64s(self.x);
169 let (y_head, y_tail) = S::as_mut_simd_f64s(self.y);
170 let alpha = simd.splat_f64s(self.alpha);
171 for (y, &x) in y_head.iter_mut().zip(x_head) {
172 *y = simd.mul_add_f64s(alpha, x, *y);
173 }
174 for (y, &x) in y_tail.iter_mut().zip(x_tail) {
175 *y += self.alpha * x;
176 }
177 }
178}
179
180pub fn pack_block<T: Copy>(
182 a: ArrayView2<'_, T>,
183 row: usize,
184 col: usize,
185 rows: usize,
186 cols: usize,
187) -> Result<Array2<T>> {
188 if row > a.rows()
189 || col > a.cols()
190 || rows > a.rows().saturating_sub(row)
191 || cols > a.cols().saturating_sub(col)
192 {
193 return Err(Error::IndexOutOfBounds);
194 }
195 Ok(Array2::from_fn([rows, cols], |i, j| a[(row + i, col + j)]))
196}
197
198pub fn unpack_block<T: Copy>(
200 block: ArrayView2<'_, T>,
201 mut dst: ArrayViewMut2<'_, T>,
202 row: usize,
203 col: usize,
204) -> Result<()> {
205 if row > dst.rows()
206 || col > dst.cols()
207 || block.rows() > dst.rows().saturating_sub(row)
208 || block.cols() > dst.cols().saturating_sub(col)
209 {
210 return Err(Error::IndexOutOfBounds);
211 }
212 for i in 0..block.rows() {
213 for j in 0..block.cols() {
214 dst[(row + i, col + j)] = block[(i, j)];
215 }
216 }
217 Ok(())
218}
219
220pub fn gemm<T: Float>(
223 alpha: T,
224 a: ArrayView2<'_, T>,
225 trans_a: bool,
226 b: ArrayView2<'_, T>,
227 trans_b: bool,
228 beta: T,
229 c: ArrayViewMut2<'_, T>,
230) -> Result<()> {
231 let mut workspace = Workspace::new();
232 gemm_with_workspace(alpha, a, trans_a, b, trans_b, beta, c, &mut workspace)
233}
234
235#[allow(clippy::too_many_arguments)]
238pub fn gemm_with_workspace<T: Float>(
239 alpha: T,
240 a: ArrayView2<'_, T>,
241 trans_a: bool,
242 b: ArrayView2<'_, T>,
243 trans_b: bool,
244 beta: T,
245 c: ArrayViewMut2<'_, T>,
246 workspace: &mut Workspace<T>,
247) -> Result<()> {
248 gemm_blocked_workspace(
249 GemmBlocked {
250 alpha,
251 a,
252 trans_a,
253 b,
254 trans_b,
255 beta,
256 c,
257 block_size: 32,
258 },
259 workspace,
260 )
261}
262
/// Arguments of one blocked GEMM call `C ← alpha · op(A) · op(B) + beta · C`, where `op(X)`
/// is `X` or `Xᵀ` depending on the matching `trans_*` flag.
struct GemmBlocked<'a, 'b, 'c, T> {
    /// Scale applied to the product `op(A) · op(B)`.
    alpha: T,
    /// Left operand (stored, i.e. before `op`).
    a: ArrayView2<'a, T>,
    /// When true, `op(A) = Aᵀ`.
    trans_a: bool,
    /// Right operand (stored, i.e. before `op`).
    b: ArrayView2<'b, T>,
    /// When true, `op(B) = Bᵀ`.
    trans_b: bool,
    /// Scale applied to the existing contents of `c` before the product is accumulated.
    beta: T,
    /// Output, updated in place; must be `rows(op(A)) × cols(op(B))`.
    c: ArrayViewMut2<'c, T>,
    /// Edge length of the square cache tiles; `gemm_blocked_workspace` treats 0 as 1.
    block_size: usize,
}
273
274fn gemm_blocked_workspace<T: Float>(
275 spec: GemmBlocked<'_, '_, '_, T>,
276 workspace: &mut Workspace<T>,
277) -> Result<()> {
278 let GemmBlocked {
279 alpha,
280 a,
281 trans_a,
282 b,
283 trans_b,
284 beta,
285 mut c,
286 block_size,
287 } = spec;
288 let (m, k_a) = if trans_a {
289 (a.cols(), a.rows())
290 } else {
291 (a.rows(), a.cols())
292 };
293 let (k_b, n) = if trans_b {
294 (b.cols(), b.rows())
295 } else {
296 (b.rows(), b.cols())
297 };
298 if k_a != k_b {
299 return Err(Error::shape(vec![m, k_a], vec![k_b, n]));
300 }
301 if c.shape() != [m, n] {
302 return Err(Error::shape(vec![m, n], c.shape()));
303 }
304 let block = block_size.max(1);
305
306 for i in 0..m {
307 for j in 0..n {
308 c[(i, j)] *= beta;
309 }
310 }
311
312 for i0 in (0..m).step_by(block) {
313 let ib = block.min(m - i0);
314 for p0 in (0..k_a).step_by(block) {
315 let pb = block.min(k_a - p0);
316 for j0 in (0..n).step_by(block) {
317 let jb = block.min(n - j0);
318 let (a_buffer, b_buffer) = workspace.two_buffers_mut(0, 1);
319 let a_block = a_buffer.zeros(ib * pb);
320 pack_op_block_into(a, trans_a, i0, p0, ib, pb, a_block);
321 let b_block = b_buffer.zeros(pb * jb);
322 pack_op_block_into(b, trans_b, p0, j0, pb, jb, b_block);
323 for i in (0..ib).step_by(4) {
324 for j in (0..jb).step_by(4) {
325 let rows = 4.min(ib - i);
326 let cols = 4.min(jb - j);
327 microkernel_4x4(
328 alpha,
329 PackedBlock {
330 data: &a_block[i * pb..],
331 rows,
332 cols: pb,
333 },
334 PackedBlock {
335 data: &b_block[j..],
336 rows: pb,
337 cols: jb,
338 },
339 &mut c,
340 [i0 + i, j0 + j],
341 cols,
342 );
343 }
344 }
345 }
346 }
347 }
348 Ok(())
349}
350
351fn pack_op_block_into<T: Float>(
352 a: ArrayView2<'_, T>,
353 trans: bool,
354 row: usize,
355 col: usize,
356 rows: usize,
357 cols: usize,
358 out: &mut [T],
359) {
360 for i in 0..rows {
361 for j in 0..cols {
362 out[i * cols + j] = if trans {
363 a[(col + j, row + i)]
364 } else {
365 a[(row + i, col + j)]
366 };
367 }
368 }
369}
370
/// Borrowed view of a row-major packed tile: element `(r, c)` lives at `data[r * cols + c]`.
struct PackedBlock<'a, T> {
    /// Backing storage; may extend past the tile, since callers pass sub-slices of a larger buffer.
    data: &'a [T],
    /// Number of valid rows in the tile.
    rows: usize,
    /// Row stride of `data`, i.e. the full width of the packed buffer.
    cols: usize,
}
376
/// Accumulates `alpha · (A_tile · B_tile)` into a tile of `c` with at most 4×4 entries.
///
/// * `a` — packed row-major tile of op(A): tile row `r` starts at `a.data[r * a.cols]`, and
///   `a.cols` is the shared inner dimension. Row 0 is always read; rows 1..=3 only when
///   `a.rows` exceeds their index.
/// * `b` — packed row-major op(B) panel whose `data` already starts at the tile's first
///   column, so element `(p, k)` is `b.data[p * b.cols + k]`. `b.rows` is not read here.
/// * `c_origin` — `[row, col]` of the tile's top-left corner in `c`.
/// * `c_cols` — number of valid tile columns. Columns at or past it load zeros from `b` and
///   are never written back. NOTE(review): values above 4 make `accumulate_tile` panic;
///   callers pass 1..=4.
///
/// The sixteen accumulators are separate scalars rather than an array, presumably so they
/// can live in registers. NOTE(review): confirm with profiling before restructuring.
fn microkernel_4x4<T: Float>(
    alpha: T,
    a: PackedBlock<'_, T>,
    b: PackedBlock<'_, T>,
    c: &mut ArrayViewMut2<'_, T>,
    c_origin: [usize; 2],
    c_cols: usize,
) {
    // c{r}{k} accumulates tile entry (r, k) over the inner dimension.
    let mut c00 = T::zero();
    let mut c01 = T::zero();
    let mut c02 = T::zero();
    let mut c03 = T::zero();
    let mut c10 = T::zero();
    let mut c11 = T::zero();
    let mut c12 = T::zero();
    let mut c13 = T::zero();
    let mut c20 = T::zero();
    let mut c21 = T::zero();
    let mut c22 = T::zero();
    let mut c23 = T::zero();
    let mut c30 = T::zero();
    let mut c31 = T::zero();
    let mut c32 = T::zero();
    let mut c33 = T::zero();

    for p in 0..a.cols {
        // Row p of the B panel; columns beyond c_cols are treated as zero so edge tiles
        // never read past the panel width.
        let b0 = b.data[p * b.cols];
        let b1 = if c_cols > 1 {
            b.data[p * b.cols + 1]
        } else {
            T::zero()
        };
        let b2 = if c_cols > 2 {
            b.data[p * b.cols + 2]
        } else {
            T::zero()
        };
        let b3 = if c_cols > 3 {
            b.data[p * b.cols + 3]
        } else {
            T::zero()
        };

        // Rank-1 update: column p of the A tile times row p of the B panel.
        let a0 = a.data[p];
        c00 += a0 * b0;
        c01 += a0 * b1;
        c02 += a0 * b2;
        c03 += a0 * b3;

        if a.rows > 1 {
            let a1 = a.data[a.cols + p];
            c10 += a1 * b0;
            c11 += a1 * b1;
            c12 += a1 * b2;
            c13 += a1 * b3;
        }
        if a.rows > 2 {
            let a2 = a.data[2 * a.cols + p];
            c20 += a2 * b0;
            c21 += a2 * b1;
            c22 += a2 * b2;
            c23 += a2 * b3;
        }
        if a.rows > 3 {
            let a3 = a.data[3 * a.cols + p];
            c30 += a3 * b0;
            c31 += a3 * b1;
            c32 += a3 * b2;
            c33 += a3 * b3;
        }
    }

    // Write back only the rows that exist in this (possibly partial) tile.
    accumulate_tile(alpha, c, c_origin, 0, &[c00, c01, c02, c03], c_cols);
    if a.rows > 1 {
        accumulate_tile(alpha, c, c_origin, 1, &[c10, c11, c12, c13], c_cols);
    }
    if a.rows > 2 {
        accumulate_tile(alpha, c, c_origin, 2, &[c20, c21, c22, c23], c_cols);
    }
    if a.rows > 3 {
        accumulate_tile(alpha, c, c_origin, 3, &[c30, c31, c32, c33], c_cols);
    }
}
460
461fn accumulate_tile<T: Float>(
462 alpha: T,
463 c: &mut ArrayViewMut2<'_, T>,
464 origin: [usize; 2],
465 row: usize,
466 values: &[T; 4],
467 cols: usize,
468) {
469 for col in 0..cols {
470 c[(origin[0] + row, origin[1] + col)] += alpha * values[col];
471 }
472}
473
474pub fn matmul<T: Float>(a: ArrayView2<'_, T>, b: ArrayView2<'_, T>) -> Result<Array2<T>> {
476 if a.cols() != b.rows() {
477 return Err(Error::shape(a.shape(), b.shape()));
478 }
479 let mut c = Array2::zeros([a.rows(), b.cols()]);
480 gemm(T::one(), a, false, b, false, T::zero(), c.view_mut())?;
481 Ok(c)
482}
483
484pub trait LinearOperator<T: Float> {
486 fn rows(&self) -> usize;
488 fn cols(&self) -> usize;
490
491 fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()>;
493
494 fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()>;
496
497 fn matmat(&self, x: ArrayView2<'_, T>, mut y: ArrayViewMut2<'_, T>) -> Result<()> {
499 if x.rows() != self.cols() || y.shape() != [self.rows(), x.cols()] {
500 return Err(Error::shape(vec![self.cols(), x.cols()], x.shape()));
501 }
502 for col in 0..x.cols() {
503 let mut input = vec![T::zero(); x.rows()];
504 let mut output = vec![T::zero(); y.rows()];
505 for row in 0..x.rows() {
506 input[row] = x[(row, col)];
507 }
508 self.matvec(&input, &mut output)?;
509 for row in 0..y.rows() {
510 y[(row, col)] = output[row];
511 }
512 }
513 Ok(())
514 }
515
516 fn t_matmat(&self, x: ArrayView2<'_, T>, mut y: ArrayViewMut2<'_, T>) -> Result<()> {
518 if x.rows() != self.rows() || y.shape() != [self.cols(), x.cols()] {
519 return Err(Error::shape(vec![self.rows(), x.cols()], x.shape()));
520 }
521 for col in 0..x.cols() {
522 let mut input = vec![T::zero(); x.rows()];
523 let mut output = vec![T::zero(); y.rows()];
524 for row in 0..x.rows() {
525 input[row] = x[(row, col)];
526 }
527 self.t_matvec(&input, &mut output)?;
528 for row in 0..y.rows() {
529 y[(row, col)] = output[row];
530 }
531 }
532 Ok(())
533 }
534}
535
536impl<T: Float> LinearOperator<T> for Array2<T> {
537 fn rows(&self) -> usize {
538 self.rows()
539 }
540
541 fn cols(&self) -> usize {
542 self.cols()
543 }
544
545 fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
546 if x.len() != self.cols() || y.len() != self.rows() {
547 return Err(Error::shape(
548 vec![self.cols(), self.rows()],
549 vec![x.len(), y.len()],
550 ));
551 }
552 for i in 0..self.rows() {
553 let mut sum = T::zero();
554 for j in 0..self.cols() {
555 sum += self[(i, j)] * x[j];
556 }
557 y[i] = sum;
558 }
559 Ok(())
560 }
561
562 fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
563 if x.len() != self.rows() || y.len() != self.cols() {
564 return Err(Error::shape(
565 vec![self.rows(), self.cols()],
566 vec![x.len(), y.len()],
567 ));
568 }
569 y.fill(T::zero());
570 for i in 0..self.rows() {
571 for j in 0..self.cols() {
572 y[j] += self[(i, j)] * x[i];
573 }
574 }
575 Ok(())
576 }
577
578 fn matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
579 gemm(T::one(), self.view(), false, x, false, T::zero(), y)
580 }
581
582 fn t_matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
583 gemm(T::one(), self.view(), true, x, false, T::zero(), y)
584 }
585}
586
587impl<T: Float> LinearOperator<T> for ArrayView2<'_, T> {
588 fn rows(&self) -> usize {
589 self.rows()
590 }
591
592 fn cols(&self) -> usize {
593 self.cols()
594 }
595
596 fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
597 if x.len() != self.cols() || y.len() != self.rows() {
598 return Err(Error::shape(
599 vec![self.cols(), self.rows()],
600 vec![x.len(), y.len()],
601 ));
602 }
603 for i in 0..self.rows() {
604 let mut sum = T::zero();
605 for j in 0..self.cols() {
606 sum += self[(i, j)] * x[j];
607 }
608 y[i] = sum;
609 }
610 Ok(())
611 }
612
613 fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
614 if x.len() != self.rows() || y.len() != self.cols() {
615 return Err(Error::shape(
616 vec![self.rows(), self.cols()],
617 vec![x.len(), y.len()],
618 ));
619 }
620 y.fill(T::zero());
621 for i in 0..self.rows() {
622 for j in 0..self.cols() {
623 y[j] += self[(i, j)] * x[i];
624 }
625 }
626 Ok(())
627 }
628
629 fn matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
630 gemm(T::one(), *self, false, x, false, T::zero(), y)
631 }
632
633 fn t_matmat(&self, x: ArrayView2<'_, T>, y: ArrayViewMut2<'_, T>) -> Result<()> {
634 gemm(T::one(), *self, true, x, false, T::zero(), y)
635 }
636}
637
/// Adapter that presents the wrapped operator as its transpose.
#[derive(Clone, Copy, Debug)]
pub struct Transpose<A> {
    inner: A,
}

impl<A> Transpose<A> {
    /// Wraps `inner` so that it is applied as `innerᵀ`.
    pub fn new(inner: A) -> Self {
        Transpose { inner }
    }
}
650
651impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for Transpose<A> {
652 fn rows(&self) -> usize {
653 self.inner.cols()
654 }
655
656 fn cols(&self) -> usize {
657 self.inner.rows()
658 }
659
660 fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
661 self.inner.t_matvec(x, y)
662 }
663
664 fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
665 self.inner.matvec(x, y)
666 }
667}
668
669#[derive(Clone, Debug)]
671pub struct CenteredOperator<A, T> {
672 inner: A,
673 means: Vec<T>,
674}
675
676impl<A, T: Float> CenteredOperator<A, T> {
677 pub fn new(inner: A, means: Vec<T>) -> Self {
679 Self { inner, means }
680 }
681
682 pub fn means(&self) -> &[T] {
684 &self.means
685 }
686}
687
688impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for CenteredOperator<A, T> {
689 fn rows(&self) -> usize {
690 self.inner.rows()
691 }
692
693 fn cols(&self) -> usize {
694 self.inner.cols()
695 }
696
697 fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
698 if self.means.len() != self.cols() {
699 return Err(Error::shape(vec![self.cols()], vec![self.means.len()]));
700 }
701 self.inner.matvec(x, y)?;
702 let correction: T = self.means.iter().zip(x).map(|(&mean, &xj)| mean * xj).sum();
703 for yi in y {
704 *yi -= correction;
705 }
706 Ok(())
707 }
708
709 fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
710 if self.means.len() != self.cols() {
711 return Err(Error::shape(vec![self.cols()], vec![self.means.len()]));
712 }
713 self.inner.t_matvec(x, y)?;
714 let total: T = x.iter().copied().sum();
715 for (yj, &mean) in y.iter_mut().zip(&self.means) {
716 *yj -= mean * total;
717 }
718 Ok(())
719 }
720}
721
722#[derive(Clone, Debug)]
724pub struct ColumnScaledOperator<A, T> {
725 inner: A,
726 scales: Vec<T>,
727}
728
729impl<A, T: Float> ColumnScaledOperator<A, T> {
730 pub fn new(inner: A, scales: Vec<T>) -> Self {
732 Self { inner, scales }
733 }
734
735 pub fn scales(&self) -> &[T] {
737 &self.scales
738 }
739}
740
741impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for ColumnScaledOperator<A, T> {
742 fn rows(&self) -> usize {
743 self.inner.rows()
744 }
745
746 fn cols(&self) -> usize {
747 self.inner.cols()
748 }
749
750 fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
751 if self.scales.len() != self.cols() {
752 return Err(Error::shape(vec![self.cols()], vec![self.scales.len()]));
753 }
754 if x.len() != self.cols() {
755 return Err(Error::shape(vec![self.cols()], vec![x.len()]));
756 }
757 let scaled = x
758 .iter()
759 .zip(&self.scales)
760 .map(|(&value, &scale)| value * scale)
761 .collect::<Vec<_>>();
762 self.inner.matvec(&scaled, y)
763 }
764
765 fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
766 if self.scales.len() != self.cols() {
767 return Err(Error::shape(vec![self.cols()], vec![self.scales.len()]));
768 }
769 self.inner.t_matvec(x, y)?;
770 for (yj, &scale) in y.iter_mut().zip(&self.scales) {
771 *yj *= scale;
772 }
773 Ok(())
774 }
775}
776
777#[derive(Clone, Debug)]
779pub struct RowScaledOperator<A, T> {
780 inner: A,
781 scales: Vec<T>,
782}
783
784impl<A, T: Float> RowScaledOperator<A, T> {
785 pub fn new(inner: A, scales: Vec<T>) -> Self {
787 Self { inner, scales }
788 }
789
790 pub fn scales(&self) -> &[T] {
792 &self.scales
793 }
794}
795
796impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for RowScaledOperator<A, T> {
797 fn rows(&self) -> usize {
798 self.inner.rows()
799 }
800
801 fn cols(&self) -> usize {
802 self.inner.cols()
803 }
804
805 fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
806 if self.scales.len() != self.rows() {
807 return Err(Error::shape(vec![self.rows()], vec![self.scales.len()]));
808 }
809 self.inner.matvec(x, y)?;
810 for (yi, &scale) in y.iter_mut().zip(&self.scales) {
811 *yi *= scale;
812 }
813 Ok(())
814 }
815
816 fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
817 if self.scales.len() != self.rows() {
818 return Err(Error::shape(vec![self.rows()], vec![self.scales.len()]));
819 }
820 if x.len() != self.rows() {
821 return Err(Error::shape(vec![self.rows()], vec![x.len()]));
822 }
823 let scaled = x
824 .iter()
825 .zip(&self.scales)
826 .map(|(&value, &scale)| value * scale)
827 .collect::<Vec<_>>();
828 self.inner.t_matvec(&scaled, y)
829 }
830}
831
832#[derive(Clone, Debug)]
834pub struct StandardizedOperator<A, T> {
835 inner: A,
836 means: Vec<T>,
837 scales: Vec<T>,
838}
839
840impl<A, T: Float> StandardizedOperator<A, T> {
841 pub fn new(inner: A, means: Vec<T>, scales: Vec<T>) -> Self {
843 Self {
844 inner,
845 means,
846 scales,
847 }
848 }
849
850 pub fn means(&self) -> &[T] {
852 &self.means
853 }
854
855 pub fn scales(&self) -> &[T] {
857 &self.scales
858 }
859}
860
861impl<T: Float, A: LinearOperator<T>> LinearOperator<T> for StandardizedOperator<A, T> {
862 fn rows(&self) -> usize {
863 self.inner.rows()
864 }
865
866 fn cols(&self) -> usize {
867 self.inner.cols()
868 }
869
870 fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
871 validate_standardized_parts(self.cols(), &self.means, &self.scales)?;
872 if x.len() != self.cols() {
873 return Err(Error::shape(vec![self.cols()], vec![x.len()]));
874 }
875 let scaled = x
876 .iter()
877 .zip(&self.scales)
878 .map(|(&value, &scale)| value / scale)
879 .collect::<Vec<_>>();
880 self.inner.matvec(&scaled, y)?;
881 let correction: T = self
882 .means
883 .iter()
884 .zip(&scaled)
885 .map(|(&mean, &xj)| mean * xj)
886 .sum();
887 for yi in y {
888 *yi -= correction;
889 }
890 Ok(())
891 }
892
893 fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
894 validate_standardized_parts(self.cols(), &self.means, &self.scales)?;
895 self.inner.t_matvec(x, y)?;
896 let total: T = x.iter().copied().sum();
897 for ((yj, &mean), &scale) in y.iter_mut().zip(&self.means).zip(&self.scales) {
898 *yj = (*yj - mean * total) / scale;
899 }
900 Ok(())
901 }
902}
903
904fn validate_standardized_parts<T: Float>(cols: usize, means: &[T], scales: &[T]) -> Result<()> {
905 if means.len() != cols {
906 return Err(Error::shape(vec![cols], vec![means.len()]));
907 }
908 if scales.len() != cols {
909 return Err(Error::shape(vec![cols], vec![scales.len()]));
910 }
911 if scales.iter().any(|&scale| scale == T::zero()) {
912 return Err(Error::NumericalFailure("standardization scale is zero"));
913 }
914 Ok(())
915}
916
/// Tuning knobs for [`randomized_svd`] and related functions.
#[derive(Clone, Debug)]
pub struct RandomizedSvdOptions {
    /// Number of singular triplets to return; must lie in `1..=min(rows, cols)`.
    pub rank: usize,
    /// Extra random sample columns drawn beyond `rank`. The total sample count is capped at
    /// `min(rows, cols)` by `randomized_range_finder`.
    pub oversampling: usize,
    /// Number of power (subspace) iterations applied to the random sample.
    pub power_iterations: usize,
    /// RNG seed. `None` makes `randomized_range_finder` fall back to a fixed built-in seed,
    /// so results are deterministic either way.
    pub seed: Option<u64>,
    /// Maximum acceptable Frobenius reconstruction error. Only `randomized_svd_with_error`
    /// consults it, returning `Error::NotConverged` when the error is larger.
    pub tolerance: Option<f64>,
    /// When false, `SvdResult::u` is returned as a 0×0 array.
    pub compute_u: bool,
    /// When false, `SvdResult::vt` is returned as a 0×0 array.
    pub compute_vt: bool,
}

impl Default for RandomizedSvdOptions {
    /// Rank 2, oversampling 8, one power iteration, built-in seed, no tolerance, and both
    /// factors computed.
    fn default() -> Self {
        Self {
            rank: 2,
            oversampling: 8,
            power_iterations: 1,
            seed: None,
            tolerance: None,
            compute_u: true,
            compute_vt: true,
        }
    }
}
949
/// Truncated singular value decomposition `A ≈ U · diag(s) · Vᵀ`.
#[derive(Clone, Debug, PartialEq)]
pub struct SvdResult<T> {
    /// Left singular vectors as columns (`rows × k`), or 0×0 when not requested.
    pub u: Array2<T>,
    /// Singular values; `svd_small` produces them in non-increasing order.
    pub s: Vec<T>,
    /// Right singular vectors as rows (`k × cols`), or 0×0 when not requested.
    pub vt: Array2<T>,
}
960
/// Eigendecomposition of a symmetric matrix.
#[derive(Clone, Debug, PartialEq)]
pub struct EighResult<T> {
    /// Eigenvalues; `eigh_small` returns them in non-increasing order.
    pub eigenvalues: Vec<T>,
    /// Eigenvectors stored as columns, column `j` paired with `eigenvalues[j]`.
    pub eigenvectors: Array2<T>,
}
969
/// Thin QR factorisation `A ≈ Q · R` as produced by `qr`.
#[derive(Clone, Debug, PartialEq)]
pub struct QrResult<T> {
    /// Orthonormal columns (`rows × k`), one per numerically independent column of `A`.
    pub q: Array2<T>,
    /// Coefficients (`k × cols`) in row-echelon form: row `r` is zero left of the column of
    /// `A` that produced `q`'s column `r`.
    pub r: Array2<T>,
}
978
979pub fn qr<T: Float>(a: ArrayView2<'_, T>) -> Result<QrResult<T>> {
981 let mut columns: Vec<Vec<T>> = Vec::new();
982 let mut r_rows: Vec<Vec<T>> = Vec::new();
983 for j in 0..a.cols() {
984 let mut column = (0..a.rows()).map(|i| a[(i, j)]).collect::<Vec<_>>();
985 for _ in 0..2 {
986 for (idx, prev) in columns.iter().enumerate() {
987 let mut projection = T::zero();
988 for i in 0..a.rows() {
989 projection += prev[i] * column[i];
990 }
991 r_rows[idx][j] += projection;
992 for i in 0..a.rows() {
993 column[i] -= projection * prev[i];
994 }
995 }
996 }
997 let mut norm = T::zero();
998 for &value in &column {
999 norm += value * value;
1000 }
1001 norm = norm.sqrt();
1002 if norm <= T::from_f64(1e-12) {
1003 continue;
1004 }
1005 if !norm.is_finite() {
1006 return Err(Error::NumericalFailure("non-finite QR column norm"));
1007 }
1008 for value in &mut column {
1009 *value /= norm;
1010 }
1011 columns.push(column);
1012 let mut r_row = vec![T::zero(); a.cols()];
1013 r_row[j] = norm;
1014 r_rows.push(r_row);
1015 }
1016 if columns.is_empty() && a.cols() > 0 {
1017 return Err(Error::NumericalFailure(
1018 "matrix has no independent QR columns",
1019 ));
1020 }
1021 let q = Array2::from_fn([a.rows(), columns.len()], |i, j| columns[j][i]);
1022 let r = Array2::from_fn([r_rows.len(), a.cols()], |i, j| r_rows[i][j]);
1023 Ok(QrResult { q, r })
1024}
1025
1026pub fn thin_qr<T: Float>(a: ArrayView2<'_, T>) -> Result<Array2<T>> {
1028 Ok(qr(a)?.q)
1029}
1030
1031pub fn reorthogonalize<T: Float>(q: ArrayView2<'_, T>) -> Result<Array2<T>> {
1033 thin_qr(q)
1034}
1035
1036pub fn randomized_range_finder<T: Float, A: LinearOperator<T>>(
1038 a: &A,
1039 rank: usize,
1040 oversampling: usize,
1041 power_iterations: usize,
1042 seed: Option<u64>,
1043) -> Result<Array2<T>> {
1044 let l = (rank + oversampling).min(a.cols()).min(a.rows());
1045 if rank == 0 || rank > a.rows().min(a.cols()) {
1046 return Err(Error::RankTooLarge {
1047 requested: rank,
1048 max: a.rows().min(a.cols()),
1049 });
1050 }
1051 let mut rng = SmallRng::new(seed.unwrap_or(0x5eed_1234_abcd_9876));
1052 let omega = Array2::from_fn([a.cols(), l], |_, _| rng.normal::<T>());
1053 let mut y = Array2::zeros([a.rows(), l]);
1054 a.matmat(omega.view(), y.view_mut())?;
1055
1056 for _ in 0..power_iterations {
1057 let q = thin_qr(y.view())?;
1058 let mut z = Array2::zeros([a.cols(), q.cols()]);
1059 a.t_matmat(q.view(), z.view_mut())?;
1060 y = Array2::zeros([a.rows(), z.cols()]);
1061 a.matmat(z.view(), y.view_mut())?;
1062 }
1063
1064 thin_qr(y.view())
1065}
1066
1067pub fn randomized_svd<T: Float, A: LinearOperator<T>>(
1069 a: &A,
1070 options: RandomizedSvdOptions,
1071) -> Result<SvdResult<T>> {
1072 if options.rank == 0 || options.rank > a.rows().min(a.cols()) {
1073 return Err(Error::RankTooLarge {
1074 requested: options.rank,
1075 max: a.rows().min(a.cols()),
1076 });
1077 }
1078 let q = randomized_range_finder(
1079 a,
1080 options.rank,
1081 options.oversampling,
1082 options.power_iterations,
1083 options.seed,
1084 )?;
1085 let mut at_q = Array2::zeros([a.cols(), q.cols()]);
1086 a.t_matmat(q.view(), at_q.view_mut())?;
1087 let b = Array2::clone_contiguous(at_q.transpose_view());
1088 let small = svd_small(b.view())?;
1089 let rank = options.rank.min(small.s.len());
1090 let u = if options.compute_u {
1091 let projected = matmul(q.view(), small.u.view())?;
1092 Array2::from_fn([a.rows(), rank], |i, j| projected[(i, j)])
1093 } else {
1094 Array2::zeros([0, 0])
1095 };
1096 let s = small.s.into_iter().take(rank).collect();
1097 let vt = if options.compute_vt {
1098 Array2::from_fn([rank, a.cols()], |i, j| small.vt[(i, j)])
1099 } else {
1100 Array2::zeros([0, 0])
1101 };
1102 Ok(SvdResult { u, s, vt })
1103}
1104
1105pub fn randomized_svd_with_error<T: Float>(
1112 a: ArrayView2<'_, T>,
1113 options: RandomizedSvdOptions,
1114) -> Result<(SvdResult<T>, T)> {
1115 let compute_u = options.compute_u;
1116 let compute_vt = options.compute_vt;
1117 let work_options = RandomizedSvdOptions {
1118 compute_u: true,
1119 compute_vt: true,
1120 ..options.clone()
1121 };
1122 let mut result = randomized_svd(&a, work_options)?;
1123 let error = approx_reconstruction_error(a, result.u.view(), &result.s, result.vt.view())?;
1124 if let Some(tolerance) = options.tolerance
1125 && error.to_f64() > tolerance
1126 {
1127 return Err(Error::NotConverged);
1128 }
1129 if !compute_u {
1130 result.u = Array2::zeros([0, 0]);
1131 }
1132 if !compute_vt {
1133 result.vt = Array2::zeros([0, 0]);
1134 }
1135 Ok((result, error))
1136}
1137
1138pub fn batch_randomized_svd<T: Float>(
1140 a: ArrayView3<'_, T>,
1141 axis: Axis3,
1142 options: RandomizedSvdOptions,
1143) -> Result<Vec<SvdResult<T>>> {
1144 let axis_index = axis.index();
1145 let mut results = Vec::with_capacity(a.shape()[axis_index]);
1146 for index in 0..a.shape()[axis_index] {
1147 let matrix = a.matrix_at(axis_index, index)?;
1148 results.push(randomized_svd(&matrix, options.clone())?);
1149 }
1150 Ok(results)
1151}
1152
1153pub fn batch_randomized_svd_parallel<T: Float>(
1157 a: ArrayView3<'_, T>,
1158 axis: Axis3,
1159 options: RandomizedSvdOptions,
1160) -> Result<Vec<SvdResult<T>>> {
1161 let axis_index = axis.index();
1162 (0..a.shape()[axis_index])
1163 .into_par_iter()
1164 .map(|index| {
1165 let matrix = a.matrix_at(axis_index, index)?;
1166 randomized_svd(&matrix, options.clone())
1167 })
1168 .collect()
1169}
1170
1171pub fn approx_reconstruction_error<T: Float>(
1173 a: ArrayView2<'_, T>,
1174 u: ArrayView2<'_, T>,
1175 s: &[T],
1176 vt: ArrayView2<'_, T>,
1177) -> Result<T> {
1178 if u.rows() != a.rows() || u.cols() != s.len() {
1179 return Err(Error::shape(vec![a.rows(), s.len()], u.shape()));
1180 }
1181 if vt.rows() != s.len() || vt.cols() != a.cols() {
1182 return Err(Error::shape(vec![s.len(), a.cols()], vt.shape()));
1183 }
1184
1185 let mut residual = T::zero();
1186 for i in 0..a.rows() {
1187 for j in 0..a.cols() {
1188 let mut approx = T::zero();
1189 for r in 0..s.len() {
1190 approx += u[(i, r)] * s[r] * vt[(r, j)];
1191 }
1192 let diff = a[(i, j)] - approx;
1193 residual += diff * diff;
1194 }
1195 }
1196 Ok(residual.sqrt())
1197}
1198
1199pub fn explained_variance_ratio<T: Float>(s: &[T]) -> Vec<T> {
1201 let total: T = s.iter().copied().map(|value| value * value).sum();
1202 if total == T::zero() {
1203 return vec![T::zero(); s.len()];
1204 }
1205 s.iter()
1206 .copied()
1207 .map(|value| value * value / total)
1208 .collect()
1209}
1210
1211pub fn eigh_small<T: Float>(a: ArrayView2<'_, T>) -> Result<EighResult<T>> {
1213 if a.rows() != a.cols() {
1214 return Err(Error::shape([a.rows(), a.rows()], a.shape()));
1215 }
1216 for i in 0..a.rows() {
1217 for j in (i + 1)..a.cols() {
1218 if (a[(i, j)] - a[(j, i)]).abs() > T::from_f64(1e-9) {
1219 return Err(Error::NumericalFailure("matrix is not symmetric"));
1220 }
1221 }
1222 }
1223
1224 let mut eig = jacobi_symmetric(Array2::from_fn(a.shape(), |i, j| a[(i, j)].to_f64()))?;
1225 eig.sort_by(|left, right| {
1226 right
1227 .0
1228 .partial_cmp(&left.0)
1229 .unwrap_or(core::cmp::Ordering::Equal)
1230 });
1231
1232 let eigenvalues = eig
1233 .iter()
1234 .map(|(value, _)| T::from_f64(*value))
1235 .collect::<Vec<_>>();
1236 let eigenvectors = Array2::from_fn([a.rows(), a.cols()], |i, j| T::from_f64(eig[j].1[i]));
1237 Ok(EighResult {
1238 eigenvalues,
1239 eigenvectors,
1240 })
1241}
1242
1243pub fn svd_small<T: Float>(a: ArrayView2<'_, T>) -> Result<SvdResult<T>> {
1245 let gram = gram_left(a);
1246 let mut eig = jacobi_symmetric(gram)?;
1247 eig.sort_by(|left, right| {
1248 right
1249 .0
1250 .partial_cmp(&left.0)
1251 .unwrap_or(core::cmp::Ordering::Equal)
1252 });
1253
1254 let rank = eig.len().min(a.rows()).min(a.cols());
1255 let mut u = Array2::zeros([a.rows(), rank]);
1256 let mut s = vec![T::zero(); rank];
1257 for j in 0..rank {
1258 let value = eig[j].0.max(0.0).sqrt();
1259 s[j] = T::from_f64(value);
1260 for i in 0..a.rows() {
1261 u[(i, j)] = T::from_f64(eig[j].1[i]);
1262 }
1263 }
1264
1265 let mut vt = Array2::zeros([rank, a.cols()]);
1266 for r in 0..rank {
1267 if s[r] <= T::from_f64(1e-12) {
1268 continue;
1269 }
1270 for col in 0..a.cols() {
1271 let mut value = T::zero();
1272 for row in 0..a.rows() {
1273 value += u[(row, r)] * a[(row, col)];
1274 }
1275 vt[(r, col)] = value / s[r];
1276 }
1277 }
1278 Ok(SvdResult { u, s, vt })
1279}
1280
1281fn gram_left<T: Float>(a: ArrayView2<'_, T>) -> Array2<f64> {
1282 Array2::from_fn([a.rows(), a.rows()], |i, j| {
1283 let mut sum = 0.0;
1284 for col in 0..a.cols() {
1285 sum += a[(i, col)].to_f64() * a[(j, col)].to_f64();
1286 }
1287 sum
1288 })
1289}
1290
1291fn jacobi_symmetric(mut a: Array2<f64>) -> Result<Vec<(f64, Vec<f64>)>> {
1292 if a.rows() != a.cols() {
1293 return Err(Error::shape([a.rows(), a.rows()], a.shape()));
1294 }
1295 let n = a.rows();
1296 let mut v = Array2::from_fn([n, n], |i, j| if i == j { 1.0 } else { 0.0 });
1297 let max_iter = 64usize.saturating_mul(n.max(1)).saturating_mul(n.max(1));
1298
1299 for _ in 0..max_iter {
1300 let mut p = 0;
1301 let mut q = 0;
1302 let mut max = 0.0;
1303 for i in 0..n {
1304 for j in (i + 1)..n {
1305 let value = a[(i, j)].abs();
1306 if value > max {
1307 max = value;
1308 p = i;
1309 q = j;
1310 }
1311 }
1312 }
1313 if max < 1e-12 {
1314 let mut result = Vec::with_capacity(n);
1315 for col in 0..n {
1316 let mut vector = Vec::with_capacity(n);
1317 for row in 0..n {
1318 vector.push(v[(row, col)]);
1319 }
1320 result.push((a[(col, col)], vector));
1321 }
1322 return Ok(result);
1323 }
1324
1325 let app = a[(p, p)];
1326 let aqq = a[(q, q)];
1327 let apq = a[(p, q)];
1328 let tau = (aqq - app) / (2.0 * apq);
1329 let t = tau.signum() / (tau.abs() + (1.0 + tau * tau).sqrt());
1330 let c = 1.0 / (1.0 + t * t).sqrt();
1331 let s = t * c;
1332
1333 for k in 0..n {
1334 if k != p && k != q {
1335 let akp = a[(k, p)];
1336 let akq = a[(k, q)];
1337 let new_kp = c * akp - s * akq;
1338 let new_kq = s * akp + c * akq;
1339 a[(k, p)] = new_kp;
1340 a[(p, k)] = new_kp;
1341 a[(k, q)] = new_kq;
1342 a[(q, k)] = new_kq;
1343 }
1344 }
1345
1346 a[(p, p)] = c * c * app - 2.0 * s * c * apq + s * s * aqq;
1347 a[(q, q)] = s * s * app + 2.0 * s * c * apq + c * c * aqq;
1348 a[(p, q)] = 0.0;
1349 a[(q, p)] = 0.0;
1350
1351 for k in 0..n {
1352 let vkp = v[(k, p)];
1353 let vkq = v[(k, q)];
1354 v[(k, p)] = c * vkp - s * vkq;
1355 v[(k, q)] = s * vkp + c * vkq;
1356 }
1357 }
1358
1359 Err(Error::NotConverged)
1360}