1use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
6use scirs2_core::RngExt;
7
8pub mod extensions;
9pub use extensions::*;
10
11pub mod advanced;
12pub use advanced::*;
13
14#[cfg(test)]
15mod tests;
16
/// How sparsity is specified or penalized by the encoders in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum SparsityMeasure {
    /// Exact number of non-zero coefficients allowed (l0 constraint).
    L0(usize),
    /// l1 penalty weight (lasso-style).
    L1(f32),
    /// Elastic-net combination of two penalty weights
    /// (presumably l1 and l2 — confirm at call sites).
    ElasticNet(f32, f32),
}
29
/// A sparse representation of a signal: non-zero weights paired with the
/// dictionary indices they belong to (`coefficients[i]` goes with atom
/// `support[i]`).
#[derive(Debug, Clone)]
pub struct SparseCode {
    /// Non-zero coefficient values.
    pub coefficients: Vec<f32>,
    /// Dictionary atom index for each coefficient.
    pub support: Vec<usize>,
}

impl SparseCode {
    /// Counts coefficients that are exactly non-zero.
    pub fn nnz(&self) -> usize {
        self.coefficients.iter().filter(|c| **c != 0.0).count()
    }

    /// Rebuilds the dense signal as the weighted sum of the supported atoms.
    ///
    /// Returns an empty vector when either the dictionary or the support is
    /// empty; out-of-range atom indices are silently skipped.
    pub fn reconstruct(&self, dictionary: &[Vec<f32>]) -> Vec<f32> {
        if dictionary.is_empty() || self.support.is_empty() {
            return Vec::new();
        }
        let mut signal = vec![0.0_f32; dictionary[0].len()];
        for (&atom_idx, &weight) in self.support.iter().zip(&self.coefficients) {
            if let Some(atom) = dictionary.get(atom_idx) {
                for (acc, &component) in signal.iter_mut().zip(atom) {
                    *acc += weight * component;
                }
            }
        }
        signal
    }
}
64
65pub struct OrthogonalMatchingPursuit {
69 pub n_nonzero: usize,
71}
72
73impl OrthogonalMatchingPursuit {
74 pub fn encode(&self, signal: &[f32], dictionary: &[Vec<f32>]) -> SparseCode {
80 let n_atoms = dictionary.len();
81 let k = self.n_nonzero.min(n_atoms);
82 let mut residual = signal.to_vec();
83 let mut support: Vec<usize> = Vec::with_capacity(k);
84
85 for _ in 0..k {
86 let mut best_idx = 0;
88 let mut best_corr = f32::NEG_INFINITY;
89 for (j, atom) in dictionary.iter().enumerate() {
90 if support.contains(&j) {
91 continue;
92 }
93 let corr: f32 = dot(&residual, atom).abs();
94 if corr > best_corr {
95 best_corr = corr;
96 best_idx = j;
97 }
98 }
99 support.push(best_idx);
100
101 let coefficients = omp_least_squares(signal, dictionary, &support);
103
104 residual = signal.to_vec();
106 for (&idx, &c) in support.iter().zip(coefficients.iter()) {
107 for (r, &a) in residual.iter_mut().zip(dictionary[idx].iter()) {
108 *r -= c * a;
109 }
110 }
111
112 if norm_sq(&residual) < 1e-10 {
114 let coefs = omp_least_squares(signal, dictionary, &support);
115 return SparseCode {
116 coefficients: coefs,
117 support,
118 };
119 }
120 }
121
122 let coefficients = omp_least_squares(signal, dictionary, &support);
123 SparseCode {
124 coefficients,
125 support,
126 }
127 }
128}
129
130pub(crate) fn omp_least_squares(
132 signal: &[f32],
133 dictionary: &[Vec<f32>],
134 support: &[usize],
135) -> Vec<f32> {
136 let k = support.len();
137 let ds: Vec<Vec<f32>> = support.iter().map(|&i| dictionary[i].clone()).collect();
138 let mut gram = vec![vec![0.0_f32; k]; k];
139 for i in 0..k {
140 for j in 0..k {
141 gram[i][j] = dot(&ds[i], &ds[j]);
142 }
143 }
144 let rhs: Vec<f32> = ds.iter().map(|col| dot(col, signal)).collect();
145 cholesky_solve(&gram, &rhs).unwrap_or_else(|_| vec![0.0_f32; k])
146}
147
/// Solves the small dense linear system `a * x = b`.
///
/// NOTE(review): despite the name this is Gauss-Jordan elimination with
/// partial pivoting on the augmented matrix, not a Cholesky factorization —
/// it also works for non-symmetric `a`. The name is kept for API stability.
/// Returns `Err(())` when a pivot is numerically zero (singular system).
pub(crate) fn cholesky_solve(a: &[Vec<f32>], b: &[f32]) -> Result<Vec<f32>, ()> {
    let n = b.len();
    if n == 0 {
        return Ok(Vec::new());
    }
    // Build the augmented matrix [a | b].
    let mut aug: Vec<Vec<f32>> = a
        .iter()
        .zip(b)
        .map(|(row, &rhs)| {
            let mut r = row.clone();
            r.push(rhs);
            r
        })
        .collect();
    for col in 0..n {
        // Partial pivoting: pick the row with the largest entry in this column.
        let mut best = col;
        let mut best_abs = aug[col][col].abs();
        for row in (col + 1)..n {
            let candidate = aug[row][col].abs();
            if candidate > best_abs {
                best_abs = candidate;
                best = row;
            }
        }
        if best_abs < 1e-12 {
            return Err(()); // singular or numerically rank-deficient
        }
        aug.swap(col, best);
        // Scale the pivot row so the pivot becomes 1.
        let pivot = aug[col][col];
        aug[col].iter_mut().for_each(|v| *v /= pivot);
        // Eliminate this column from every other row (Gauss-Jordan sweep).
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = aug[row][col];
            let pivot_vals = aug[col].clone();
            for (v, &p) in aug[row].iter_mut().zip(pivot_vals.iter()) {
                *v -= factor * p;
            }
        }
    }
    // After full elimination the solution sits in the augmented column.
    Ok(aug.iter().map(|row| *row.last().unwrap_or(&0.0)).collect())
}
195
196pub struct LassoEncoder {
200 pub lambda: f32,
202 pub max_iter: usize,
204 pub tol: f32,
206}
207
208impl LassoEncoder {
209 pub fn encode_ista(&self, signal: &[f32], dictionary: &[Vec<f32>]) -> SparseCode {
214 let n_atoms = dictionary.len();
215 if n_atoms == 0 {
216 return SparseCode {
217 coefficients: Vec::new(),
218 support: Vec::new(),
219 };
220 }
221
222 let lipschitz = estimate_lipschitz(dictionary);
224 let step = if lipschitz > 1e-10 {
225 1.0 / lipschitz
226 } else {
227 0.01
228 };
229 let threshold = self.lambda * step;
230
231 let mut x = vec![0.0_f32; n_atoms];
232 for _ in 0..self.max_iter {
233 let x_old = x.clone();
234 let mut residual = signal.to_vec();
236 for (j, atom) in dictionary.iter().enumerate() {
237 let coef = x[j];
238 for (r, &a) in residual.iter_mut().zip(atom.iter()) {
239 *r -= coef * a;
240 }
241 }
242 for (j, atom) in dictionary.iter().enumerate() {
244 x[j] += dot(atom, &residual) * step;
245 }
246 for v in x.iter_mut() {
248 *v = Self::soft_threshold(*v, threshold);
249 }
250 let change: f32 = x
252 .iter()
253 .zip(x_old.iter())
254 .map(|(a, b)| (a - b).powi(2))
255 .sum::<f32>()
256 .sqrt();
257 if change < self.tol {
258 break;
259 }
260 }
261
262 let support: Vec<usize> = x
263 .iter()
264 .enumerate()
265 .filter(|(_, &v)| v != 0.0)
266 .map(|(i, _)| i)
267 .collect();
268 let coefficients: Vec<f32> = support.iter().map(|&i| x[i]).collect();
269 SparseCode {
270 coefficients,
271 support,
272 }
273 }
274
275 #[inline]
277 pub fn soft_threshold(x: f32, threshold: f32) -> f32 {
278 if x > threshold {
279 x - threshold
280 } else if x < -threshold {
281 x + threshold
282 } else {
283 0.0
284 }
285 }
286}
287
288pub(crate) fn estimate_lipschitz(dictionary: &[Vec<f32>]) -> f32 {
290 let n = dictionary.len();
291 if n == 0 {
292 return 1.0;
293 }
294 let mut v = vec![1.0_f32 / (n as f32).sqrt(); n];
295 for _ in 0..20 {
296 let mut w = vec![0.0_f32; n];
298 for (i, atom_i) in dictionary.iter().enumerate() {
299 for (j, atom_j) in dictionary.iter().enumerate() {
300 w[i] += dot(atom_i, atom_j) * v[j];
301 }
302 }
303 let nrm = norm(&w);
304 if nrm < 1e-12 {
305 return 1.0;
306 }
307 v = w.iter().map(|&x| x / nrm).collect();
308 }
309 let mut w = vec![0.0_f32; n];
311 for (i, atom_i) in dictionary.iter().enumerate() {
312 for (j, atom_j) in dictionary.iter().enumerate() {
313 w[i] += dot(atom_i, atom_j) * v[j];
314 }
315 }
316 dot(&w, &v).max(1e-10)
317}
318
/// Shared hyper-parameters for the dictionary-learning algorithms.
#[derive(Debug, Clone)]
pub struct DlConfig {
    /// Number of dictionary atoms to learn.
    pub n_atoms: usize,
    /// Sparsity level used when coding samples (passed to OMP).
    pub n_nonzero: usize,
    /// Maximum number of learning iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the dictionary change between iterations.
    pub tol: f32,
}
333
334#[derive(Debug, Clone)]
336pub struct Dictionary {
337 pub atoms: Vec<Vec<f32>>,
339 pub n_atoms: usize,
341 pub atom_dim: usize,
343}
344
345impl Dictionary {
346 pub fn new(atoms: Vec<Vec<f32>>) -> Self {
348 let n_atoms = atoms.len();
349 let atom_dim = atoms.first().map(|a| a.len()).unwrap_or(0);
350 Self {
351 atoms,
352 n_atoms,
353 atom_dim,
354 }
355 }
356
357 pub fn atom(&self, i: usize) -> &[f32] {
359 &self.atoms[i]
360 }
361
362 pub fn normalize_atoms(&mut self) {
364 for atom in self.atoms.iter_mut() {
365 let n = norm(atom);
366 if n > 1e-10 {
367 for v in atom.iter_mut() {
368 *v /= n;
369 }
370 }
371 }
372 }
373
374 pub fn coherence(&self) -> f32 {
376 let n = self.atoms.len();
377 let mut max_corr = 0.0_f32;
378 for i in 0..n {
379 for j in (i + 1)..n {
380 let c = dot(&self.atoms[i], &self.atoms[j]).abs();
381 if c > max_corr {
382 max_corr = c;
383 }
384 }
385 }
386 max_corr
387 }
388}
389
/// Dictionary learning via K-SVD: alternates OMP sparse coding of every
/// sample with per-atom rank-1 updates of the dictionary.
pub struct KSvd {
    /// Learning hyper-parameters (atom count, sparsity, iterations, tolerance).
    pub config: DlConfig,
}

impl KSvd {
    /// Learns a dictionary from `data` (one signal per element) and returns
    /// it together with the final sparse codes of every sample.
    ///
    /// Returns an empty dictionary and no codes for empty input.
    pub fn fit(&self, data: &[Vec<f32>]) -> (Dictionary, Vec<SparseCode>) {
        if data.is_empty() {
            let d = Dictionary {
                atoms: Vec::new(),
                n_atoms: 0,
                atom_dim: 0,
            };
            return (d, Vec::new());
        }
        let signal_dim = data[0].len();
        let n_atoms = self.config.n_atoms;
        let n_samples = data.len();

        // Deterministic init (fixed seed 42): copy training samples (cycled)
        // and perturb slightly so duplicated samples don't yield equal atoms.
        let mut rng = StdRng::seed_from_u64(42);
        let mut atoms: Vec<Vec<f32>> = (0..n_atoms)
            .map(|i| {
                let sample = &data[i % n_samples];
                let mut atom = sample.clone();
                for v in atom.iter_mut() {
                    *v += (rng.random::<f32>() - 0.5) * 1e-4;
                }
                atom
            })
            .collect();
        // Normalize initial atoms to unit l2 norm (skip near-zero atoms).
        for atom in atoms.iter_mut() {
            let n = norm(atom);
            if n > 1e-10 {
                for v in atom.iter_mut() {
                    *v /= n;
                }
            }
        }

        let omp = OrthogonalMatchingPursuit {
            n_nonzero: self.config.n_nonzero,
        };
        let mut codes: Vec<SparseCode> = data.iter().map(|s| omp.encode(s, &atoms)).collect();

        for _iter in 0..self.config.max_iter {
            let old_atoms = atoms.clone();

            // Update each atom in turn against the samples that use it.
            for k in 0..n_atoms {
                let using: Vec<usize> = (0..n_samples)
                    .filter(|&i| codes[i].support.contains(&k))
                    .collect();
                if using.is_empty() {
                    // Unused atom: leave it unchanged this round.
                    continue;
                }

                // Residual matrix E_k: each using sample minus the
                // contribution of every atom EXCEPT atom k.
                let e_k: Vec<Vec<f32>> = using
                    .iter()
                    .map(|&i| {
                        let mut e = data[i].clone();
                        for (&sup_idx, &coef) in
                            codes[i].support.iter().zip(codes[i].coefficients.iter())
                        {
                            if sup_idx == k {
                                continue;
                            }
                            for (ev, &av) in e.iter_mut().zip(atoms[sup_idx].iter()) {
                                *ev -= coef * av;
                            }
                        }
                        e
                    })
                    .collect();
                // Rank-1 approximation of E_k gives the new atom (u) and the
                // refreshed coefficients (sigma * v) in one shot.
                let (u, sigma, v) = rank1_svd(&e_k, signal_dim, 20);
                atoms[k] = u;
                for (sample_pos, &sample_idx) in using.iter().enumerate() {
                    let new_coef = sigma * v[sample_pos];
                    if let Some(pos) = codes[sample_idx].support.iter().position(|&s| s == k) {
                        codes[sample_idx].coefficients[pos] = new_coef;
                    }
                }
            }

            // Converged when the dictionary barely moved (Frobenius norm of
            // the difference between consecutive dictionaries).
            let change: f32 = atoms
                .iter()
                .zip(old_atoms.iter())
                .map(|(a, b)| {
                    a.iter()
                        .zip(b.iter())
                        .map(|(x, y)| (x - y).powi(2))
                        .sum::<f32>()
                })
                .sum::<f32>()
                .sqrt();
            if change < self.config.tol {
                break;
            }

            // Re-code all samples against the updated dictionary.
            codes = data.iter().map(|s| omp.encode(s, &atoms)).collect();
        }

        let dict = Dictionary::new(atoms);
        (dict, codes)
    }
}
511
512pub(crate) fn rank1_svd(
514 matrix: &[Vec<f32>],
515 _signal_dim: usize,
516 max_iter: usize,
517) -> (Vec<f32>, f32, Vec<f32>) {
518 let m = matrix.len();
519 if m == 0 {
520 return (Vec::new(), 0.0, Vec::new());
521 }
522 let n = matrix[0].len();
523 let mut v: Vec<f32> = (0..m).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
524 let mut u = vec![0.0_f32; n];
525 for _ in 0..max_iter {
526 for i in 0..n {
527 u[i] = matrix
528 .iter()
529 .zip(v.iter())
530 .map(|(row, &vj)| row[i] * vj)
531 .sum();
532 }
533 let sigma = norm(&u);
534 if sigma < 1e-12 {
535 return (vec![0.0; n], 0.0, vec![0.0; m]);
536 }
537 for x in u.iter_mut() {
538 *x /= sigma;
539 }
540
541 for j in 0..m {
543 v[j] = u
544 .iter()
545 .zip(matrix[j].iter())
546 .map(|(&ui, &aij)| ui * aij)
547 .sum();
548 }
549 let vnorm = norm(&v);
550 if vnorm < 1e-12 {
551 break;
552 }
553 for x in v.iter_mut() {
554 *x /= vnorm;
555 }
556 }
557
558 let mut mv = vec![0.0_f32; n];
559 for i in 0..n {
560 mv[i] = matrix
561 .iter()
562 .zip(v.iter())
563 .map(|(row, &vj)| row[i] * vj)
564 .sum();
565 }
566 let sigma = dot(&u, &mv);
567 (u, sigma, v)
568}
569
570pub struct OnlineDictionaryLearning {
574 pub config: DlConfig,
576 pub a_matrix: Vec<Vec<f32>>,
578 pub b_matrix: Vec<Vec<f32>>,
580}
581
582impl OnlineDictionaryLearning {
583 pub fn new(config: DlConfig, signal_dim: usize) -> Self {
585 let n_atoms = config.n_atoms;
586 let a_matrix = vec![vec![0.0_f32; n_atoms]; n_atoms];
587 let b_matrix = vec![vec![0.0_f32; n_atoms]; signal_dim];
588 Self {
589 config,
590 a_matrix,
591 b_matrix,
592 }
593 }
594
595 pub fn update(&mut self, sample: &[f32], current_dict: &Dictionary) -> Dictionary {
597 let omp = OrthogonalMatchingPursuit {
598 n_nonzero: self.config.n_nonzero,
599 };
600 let code = omp.encode(sample, ¤t_dict.atoms);
601
602 let n_atoms = self.config.n_atoms;
603 let signal_dim = sample.len();
604
605 let mut alpha = vec![0.0_f32; n_atoms];
607 for (&idx, &coef) in code.support.iter().zip(code.coefficients.iter()) {
608 if idx < n_atoms {
609 alpha[idx] = coef;
610 }
611 }
612
613 for i in 0..n_atoms {
615 for j in 0..n_atoms {
616 self.a_matrix[i][j] += alpha[i] * alpha[j];
617 }
618 }
619 for i in 0..signal_dim.min(self.b_matrix.len()) {
621 for j in 0..n_atoms {
622 self.b_matrix[i][j] += sample[i] * alpha[j];
623 }
624 }
625
626 let mut new_atoms = current_dict.atoms.clone();
628 for k in 0..n_atoms {
629 let a_kk = self.a_matrix[k][k];
630 if a_kk.abs() < 1e-10 {
631 continue;
632 }
633
634 let mut u_k = vec![0.0_f32; signal_dim];
636 for i in 0..signal_dim.min(self.b_matrix.len()) {
637 u_k[i] = self.b_matrix[i][k];
638 }
639 for j in 0..n_atoms {
640 if j == k {
641 continue;
642 }
643 let a_jk = self.a_matrix[j][k];
644 for i in 0..signal_dim.min(new_atoms[j].len()) {
645 u_k[i] -= new_atoms[j][i] * a_jk;
646 }
647 }
648 for v in u_k.iter_mut() {
649 *v /= a_kk;
650 }
651
652 let n = norm(&u_k);
654 if n > 1e-10 {
655 for v in u_k.iter_mut() {
656 *v /= n;
657 }
658 }
659 new_atoms[k] = u_k;
660 }
661
662 Dictionary::new(new_atoms)
663 }
664
665 pub fn fit_stream(&mut self, data: &[Vec<f32>]) -> Dictionary {
667 if data.is_empty() {
668 return Dictionary {
669 atoms: Vec::new(),
670 n_atoms: 0,
671 atom_dim: 0,
672 };
673 }
674 let signal_dim = data[0].len();
675 let n_atoms = self.config.n_atoms;
676
677 let mut rng = StdRng::seed_from_u64(123);
679 let mut atoms: Vec<Vec<f32>> = (0..n_atoms)
680 .map(|i| {
681 let mut atom = data[i % data.len()].clone();
682 for v in atom.iter_mut() {
683 *v += (rng.random::<f32>() - 0.5) * 1e-3;
684 }
685 atom
686 })
687 .collect();
688 for atom in atoms.iter_mut() {
689 let n = norm(atom);
690 if n > 1e-10 {
691 for v in atom.iter_mut() {
692 *v /= n;
693 }
694 }
695 }
696
697 let mut dict = Dictionary::new(atoms);
698 for sample in data.iter() {
699 dict = self.update(sample, &dict);
700 }
701 dict
702 }
703}
704
/// Specification of a compressed-sensing measurement matrix.
#[derive(Debug, Clone)]
pub enum MeasurementMatrix {
    /// i.i.d. Gaussian entries scaled by 1/sqrt(rows); fields are (rows, cols).
    Gaussian(usize, usize),
    /// Random +-1/sqrt(rows) entries; fields are (rows, cols).
    Bernoulli(usize, usize),
    /// Partial DCT-II basis with the given number of rows
    /// (column count is derived as max(2*rows, 4) by `generate_matrix`).
    Dct(usize),
}
717
/// A compressed measurement y = A x of an (unknown) signal x.
#[derive(Debug, Clone)]
pub struct CsMeasurement {
    /// Measured values.
    pub y: Vec<f32>,
    /// Number of measurements (length of `y`).
    pub n_measurements: usize,
    /// Length of the original, unmeasured signal.
    pub signal_len: usize,
}
728
729#[derive(Debug, Clone)]
731pub struct CsMatrix {
732 pub rows: Vec<Vec<f32>>,
734}
735
736impl CsMatrix {
737 pub fn nrows(&self) -> usize {
739 self.rows.len()
740 }
741 pub fn ncols(&self) -> usize {
743 self.rows.first().map(|r| r.len()).unwrap_or(0)
744 }
745
746 pub(crate) fn transpose_mul(&self, v: &[f32]) -> Vec<f32> {
748 let ncols = self.ncols();
749 let mut result = vec![0.0_f32; ncols];
750 for (row, &vi) in self.rows.iter().zip(v.iter()) {
751 for (r, &a) in result.iter_mut().zip(row.iter()) {
752 *r += a * vi;
753 }
754 }
755 result
756 }
757
758 pub(crate) fn forward_mul(&self, v: &[f32]) -> Vec<f32> {
760 self.rows.iter().map(|row| dot(row, v)).collect()
761 }
762}
763
764pub fn generate_matrix(m: &MeasurementMatrix, rng: &mut impl Rng) -> CsMatrix {
766 match m {
767 MeasurementMatrix::Gaussian(rows, cols) => {
768 let scale = 1.0 / (*rows as f32).sqrt();
769 let rows_data: Vec<Vec<f32>> = (0..*rows)
770 .map(|_| (0..*cols).map(|_| box_muller(rng) * scale).collect())
771 .collect();
772 CsMatrix { rows: rows_data }
773 }
774 MeasurementMatrix::Bernoulli(rows, cols) => {
775 let scale = 1.0 / (*rows as f32).sqrt();
776 let rows_data: Vec<Vec<f32>> = (0..*rows)
777 .map(|_| {
778 (0..*cols)
779 .map(|_| {
780 if rng.random::<f32>() > 0.5 {
781 scale
782 } else {
783 -scale
784 }
785 })
786 .collect()
787 })
788 .collect();
789 CsMatrix { rows: rows_data }
790 }
791 MeasurementMatrix::Dct(m_rows) => {
792 let n = (2 * m_rows).max(4);
794 let dct_rows: Vec<Vec<f32>> = (0..*m_rows)
795 .map(|k| {
796 (0..n)
797 .map(|j| {
798 let scale = if k == 0 {
799 (1.0 / n as f32).sqrt()
800 } else {
801 (2.0 / n as f32).sqrt()
802 };
803 let angle = std::f32::consts::PI * k as f32 * (2 * j + 1) as f32
804 / (2 * n) as f32;
805 scale * angle.cos()
806 })
807 .collect()
808 })
809 .collect();
810 CsMatrix { rows: dct_rows }
811 }
812 }
813}
814
815pub(crate) fn box_muller(rng: &mut impl Rng) -> f32 {
817 let u1 = (rng.random::<f32>()).max(1e-30);
818 let u2 = rng.random::<f32>();
819 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
820}
821
822pub fn measure(signal: &[f32], matrix: &CsMatrix) -> CsMeasurement {
824 let signal_len = signal.len();
825 let y = matrix.forward_mul(signal);
826 let n_measurements = y.len();
827 CsMeasurement {
828 y,
829 n_measurements,
830 signal_len,
831 }
832}
833
834pub struct BasisPursuit {
838 pub max_iter: usize,
840 pub rho: f32,
842 pub tol: f32,
844}
845
846impl BasisPursuit {
847 pub fn recover(&self, measurement: &CsMeasurement, matrix: &CsMatrix) -> Vec<f32> {
849 let n = measurement.signal_len;
850 let rho = self.rho;
851
852 let mut x = vec![0.0_f32; n];
854 let mut z = vec![0.0_f32; n];
855 let mut u = vec![0.0_f32; n]; let aty: Vec<f32> = matrix.transpose_mul(&measurement.y);
859
860 for _ in 0..self.max_iter {
861 let rhs: Vec<f32> = (0..n).map(|i| aty[i] + rho * (z[i] - u[i])).collect();
864 x = admm_cg(matrix, rho, &rhs, &x, 50);
865
866 let z_old = z.clone();
868 for i in 0..n {
869 z[i] = LassoEncoder::soft_threshold(x[i] + u[i], 1.0 / rho);
870 }
871
872 for i in 0..n {
874 u[i] += x[i] - z[i];
875 }
876
877 let primal: f32 = (0..n).map(|i| (x[i] - z[i]).powi(2)).sum::<f32>().sqrt();
879 let dual: f32 = (0..n)
880 .map(|i| (rho * (z[i] - z_old[i])).powi(2))
881 .sum::<f32>()
882 .sqrt();
883 if primal < self.tol && dual < self.tol {
884 break;
885 }
886 }
887 z
888 }
889}
890
891fn admm_cg(matrix: &CsMatrix, rho: f32, b: &[f32], x0: &[f32], max_iter: usize) -> Vec<f32> {
893 let n = b.len();
894 let mut x = x0.to_vec();
895 let ax = matrix.forward_mul(&x);
897 let atax = matrix.transpose_mul(&ax);
898 let mut r: Vec<f32> = (0..n).map(|i| b[i] - atax[i] - rho * x[i]).collect();
899 let mut p = r.clone();
900 let mut rsold: f32 = r.iter().map(|&v| v * v).sum();
901
902 for _ in 0..max_iter {
903 if rsold < 1e-10 {
904 break;
905 }
906 let ap_full = matrix.forward_mul(&p);
907 let atp = matrix.transpose_mul(&ap_full);
908 let ap: Vec<f32> = (0..n).map(|i| atp[i] + rho * p[i]).collect();
909 let denom: f32 = p.iter().zip(ap.iter()).map(|(&pi, &api)| pi * api).sum();
910 if denom.abs() < 1e-14 {
911 break;
912 }
913 let alpha = rsold / denom;
914 for i in 0..n {
915 x[i] += alpha * p[i];
916 r[i] -= alpha * ap[i];
917 }
918 let rsnew: f32 = r.iter().map(|&v| v * v).sum();
919 let beta = rsnew / rsold.max(1e-14);
920 for i in 0..n {
921 p[i] = r[i] + beta * p[i];
922 }
923 rsold = rsnew;
924 }
925 x
926}
927
928pub struct CoSaMP {
932 pub n_nonzero: usize,
934 pub max_iter: usize,
936}
937
938impl CoSaMP {
939 pub fn recover(&self, measurement: &CsMeasurement, matrix: &CsMatrix) -> Vec<f32> {
941 let n = measurement.signal_len;
942 let s = self.n_nonzero;
943
944 let mut x = vec![0.0_f32; n];
945
946 for _ in 0..self.max_iter {
947 let ax = matrix.forward_mul(&x);
949 let residual: Vec<f32> = measurement
950 .y
951 .iter()
952 .zip(ax.iter())
953 .map(|(&y, &a)| y - a)
954 .collect();
955 let proxy = matrix.transpose_mul(&residual);
956
957 let mut indices: Vec<usize> = (0..n).collect();
959 indices.sort_by(|&a, &b| {
960 proxy[b]
961 .abs()
962 .partial_cmp(&proxy[a].abs())
963 .unwrap_or(std::cmp::Ordering::Equal)
964 });
965 let mut support: Vec<usize> = indices[..((2 * s).min(n))].to_vec();
966
967 let current_nonzero: Vec<usize> = x
969 .iter()
970 .enumerate()
971 .filter(|(_, &v)| v != 0.0)
972 .map(|(i, _)| i)
973 .collect();
974 for idx in current_nonzero {
975 if !support.contains(&idx) {
976 support.push(idx);
977 }
978 }
979
980 let coefs = cosamp_ls(&measurement.y, matrix, &support);
982
983 let mut coef_pairs: Vec<(usize, f32)> = support
985 .iter()
986 .zip(coefs.iter())
987 .map(|(&i, &c)| (i, c))
988 .collect();
989 coef_pairs.sort_by(|a, b| {
990 b.1.abs()
991 .partial_cmp(&a.1.abs())
992 .unwrap_or(std::cmp::Ordering::Equal)
993 });
994 coef_pairs.truncate(s);
995
996 x = vec![0.0_f32; n];
997 for (i, c) in coef_pairs {
998 x[i] = c;
999 }
1000
1001 let ax_new = matrix.forward_mul(&x);
1003 let res_norm: f32 = measurement
1004 .y
1005 .iter()
1006 .zip(ax_new.iter())
1007 .map(|(&y, &a)| (y - a).powi(2))
1008 .sum::<f32>()
1009 .sqrt();
1010 if res_norm < 1e-6 {
1011 break;
1012 }
1013 }
1014 x
1015 }
1016}
1017
1018fn cosamp_ls(y: &[f32], matrix: &CsMatrix, support: &[usize]) -> Vec<f32> {
1020 let k = support.len();
1021 let m = y.len();
1022 let as_cols: Vec<Vec<f32>> = support
1024 .iter()
1025 .map(|&i| {
1026 matrix
1027 .rows
1028 .iter()
1029 .map(|row| *row.get(i).unwrap_or(&0.0))
1030 .collect()
1031 })
1032 .collect();
1033 let mut gram = vec![vec![0.0_f32; k]; k];
1035 for i in 0..k {
1036 for j in 0..k {
1037 gram[i][j] = (0..m).map(|r| as_cols[i][r] * as_cols[j][r]).sum();
1038 }
1039 }
1040 let rhs: Vec<f32> = (0..k)
1041 .map(|i| (0..m).map(|r| as_cols[i][r] * y[r]).sum())
1042 .collect();
1043 cholesky_solve(&gram, &rhs).unwrap_or_else(|_| vec![0.0_f32; k])
1044}
1045
1046pub fn rip_constant_estimate(
1048 matrix: &CsMatrix,
1049 s: usize,
1050 n_trials: usize,
1051 rng: &mut impl Rng,
1052) -> f32 {
1053 let n = matrix.ncols();
1054 if n == 0 {
1055 return 0.0;
1056 }
1057 let mut max_delta = 0.0_f32;
1058
1059 for _ in 0..n_trials {
1060 let mut support: Vec<usize> = (0..n).collect();
1062 for i in 0..s.min(n) {
1064 let j = i + (rng.random_range(0..(n - i)));
1065 support.swap(i, j);
1066 }
1067 let support = &support[..s.min(n)];
1068
1069 let mut x = vec![0.0_f32; n];
1070 let mut total_sq = 0.0_f32;
1071 for &i in support {
1072 let v = rng.random::<f32>() * 2.0 - 1.0;
1073 x[i] = v;
1074 total_sq += v * v;
1075 }
1076 if total_sq < 1e-10 {
1077 continue;
1078 }
1079 let x_norm_sq = total_sq;
1080 for v in x.iter_mut() {
1081 *v /= x_norm_sq.sqrt();
1082 }
1083
1084 let ax = matrix.forward_mul(&x);
1085 let ax_norm_sq: f32 = ax.iter().map(|&v| v * v).sum();
1086 let delta = (ax_norm_sq - 1.0).abs();
1087 if delta > max_delta {
1088 max_delta = delta;
1089 }
1090 }
1091 max_delta
1092}
1093
/// Configuration of the sparse autoencoder.
#[derive(Debug, Clone)]
pub struct SaeConfig {
    /// Input (and reconstruction) dimensionality.
    pub input_dim: usize,
    /// Hidden-layer width.
    pub hidden_dim: usize,
    /// Target mean activation for the KL sparsity penalty (used by `total_loss`).
    pub sparsity_target: f32,
    /// Weight of the sparsity penalty in the total loss.
    pub sparsity_weight: f32,
}
1108
1109#[derive(Debug, Clone)]
1111pub struct SparseAutoencoder {
1112 pub encoder_w: Vec<Vec<f32>>,
1114 pub encoder_b: Vec<f32>,
1116 pub decoder_w: Vec<Vec<f32>>,
1118 pub decoder_b: Vec<f32>,
1120}
1121
1122impl SparseAutoencoder {
1123 pub fn new(config: &SaeConfig, rng: &mut impl Rng) -> Self {
1125 let scale_enc = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
1126 let scale_dec = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
1127
1128 let encoder_w: Vec<Vec<f32>> = (0..config.hidden_dim)
1129 .map(|_| {
1130 (0..config.input_dim)
1131 .map(|_| box_muller(rng) * scale_enc)
1132 .collect()
1133 })
1134 .collect();
1135 let encoder_b = vec![0.0_f32; config.hidden_dim];
1136
1137 let decoder_w: Vec<Vec<f32>> = (0..config.input_dim)
1138 .map(|_| {
1139 (0..config.hidden_dim)
1140 .map(|_| box_muller(rng) * scale_dec)
1141 .collect()
1142 })
1143 .collect();
1144 let decoder_b = vec![0.0_f32; config.input_dim];
1145
1146 Self {
1147 encoder_w,
1148 encoder_b,
1149 decoder_w,
1150 decoder_b,
1151 }
1152 }
1153
1154 pub fn encode(&self, x: &[f32]) -> Vec<f32> {
1156 let hidden_dim = self.encoder_w.len();
1157 let mut h: Vec<f32> = (0..hidden_dim)
1159 .map(|i| {
1160 let pre_act = dot(&self.encoder_w[i], x) + self.encoder_b[i];
1161 pre_act.max(0.0) })
1163 .collect();
1164
1165 let k = ((hidden_dim as f32 * 0.1).ceil() as usize)
1167 .max(1)
1168 .min(hidden_dim);
1169 let mut indexed: Vec<(usize, f32)> = h.iter().cloned().enumerate().collect();
1170 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1171 let threshold_idx = k;
1172 let zero_indices: Vec<usize> = indexed[threshold_idx..].iter().map(|(i, _)| *i).collect();
1173 for i in zero_indices {
1174 h[i] = 0.0;
1175 }
1176 h
1177 }
1178
1179 pub fn decode(&self, z: &[f32]) -> Vec<f32> {
1181 let input_dim = self.decoder_w.len();
1182 (0..input_dim)
1183 .map(|i| dot(&self.decoder_w[i], z) + self.decoder_b[i])
1184 .collect()
1185 }
1186
1187 pub fn reconstruction_loss(&self, x: &[f32]) -> f32 {
1189 let z = self.encode(x);
1190 let x_hat = self.decode(&z);
1191 let n = x.len();
1192 if n == 0 {
1193 return 0.0;
1194 }
1195 x.iter()
1196 .zip(x_hat.iter())
1197 .map(|(&a, &b)| (a - b).powi(2))
1198 .sum::<f32>()
1199 / n as f32
1200 }
1201
1202 pub fn sparsity_loss(&self, z: &[f32]) -> f32 {
1204 if z.is_empty() {
1205 return 0.0;
1206 }
1207 let rho_hat = z.iter().map(|&v| v.clamp(0.0, 1.0)).sum::<f32>() / z.len() as f32;
1208 let rho = 0.05_f32; Self::kl_divergence_bernoulli(rho, rho_hat)
1210 }
1211
1212 pub fn total_loss(&self, x: &[f32], config: &SaeConfig) -> f32 {
1214 let z = self.encode(x);
1215 let rec = self.reconstruction_loss(x);
1216 let spar = self.sparsity_loss_with_target(&z, config.sparsity_target);
1217 rec + config.sparsity_weight * spar
1218 }
1219
1220 fn sparsity_loss_with_target(&self, z: &[f32], target: f32) -> f32 {
1221 if z.is_empty() {
1222 return 0.0;
1223 }
1224 let rho_hat = z.iter().map(|&v| v.clamp(0.0, 1.0)).sum::<f32>() / z.len() as f32;
1225 Self::kl_divergence_bernoulli(target, rho_hat)
1226 }
1227
1228 pub fn kl_divergence_bernoulli(rho: f32, rho_hat: f32) -> f32 {
1230 let eps = 1e-8_f32;
1231 let rho = rho.clamp(eps, 1.0 - eps);
1232 let rho_hat = rho_hat.clamp(eps, 1.0 - eps);
1233 rho * (rho / rho_hat).ln() + (1.0 - rho) * ((1.0 - rho) / (1.0 - rho_hat)).ln()
1234 }
1235}
1236
/// Relative l2 reconstruction error ||x - x_hat|| / ||x||. Falls back to the
/// absolute error when the original is (near) zero; empty input yields 0.
pub fn reconstruction_error(original: &[f32], reconstructed: &[f32]) -> f32 {
    if original.is_empty() {
        return 0.0;
    }
    let sq_err: f32 = original
        .iter()
        .zip(reconstructed)
        .map(|(a, b)| (a - b).powi(2))
        .sum();
    let ref_energy: f32 = original.iter().map(|v| v * v).sum();
    if ref_energy < 1e-10 {
        sq_err.sqrt()
    } else {
        (sq_err / ref_energy).sqrt()
    }
}
1255
1256pub fn sparsity_ratio(code: &SparseCode, signal_len: usize) -> f32 {
1258 if signal_len == 0 {
1259 return 0.0;
1260 }
1261 code.nnz() as f32 / signal_len as f32
1262}
1263
/// Welch lower bound on the mutual coherence of `n_atoms` unit vectors in
/// `signal_dim` dimensions: sqrt((n - d) / (d (n - 1))). Returns 0 in the
/// degenerate or undercomplete cases where the bound is vacuous.
pub fn coherence_bound(n_atoms: usize, signal_dim: usize) -> f32 {
    let n = n_atoms as f32;
    let d = signal_dim as f32;
    if n <= 1.0 || d <= 0.0 || n <= d {
        0.0
    } else {
        ((n - d) / (d * (n - 1.0))).sqrt()
    }
}
1273
/// Peak-signal-to-noise ratio (dB) between original and recovered signals.
/// Near-perfect recovery (mse < 1e-14) is capped at 100 dB; an empty
/// original yields 0; the peak defaults to 1.0 for an all-zero original.
pub fn recovery_quality(original: &[f32], recovered: &[f32]) -> f32 {
    if original.is_empty() {
        return 0.0;
    }
    let mse = original
        .iter()
        .zip(recovered)
        .map(|(a, b)| (a - b).powi(2))
        .sum::<f32>()
        / original.len() as f32;
    if mse < 1e-14 {
        return 100.0;
    }
    let mut peak = original.iter().fold(0.0_f32, |acc, &v| acc.max(v.abs()));
    if peak < 1e-10 {
        peak = 1.0;
    }
    10.0 * (peak * peak / mse).log10()
}
1295
/// Aggregate quality metrics for a learned dictionary on a dataset.
#[derive(Debug, Clone)]
pub struct SparseLearningReport {
    /// Mean relative l2 reconstruction error across samples.
    pub reconstruction_error: f32,
    /// Mean fraction of non-zero coefficients per sample.
    pub sparsity: f32,
    /// Mutual coherence of the dictionary atoms.
    pub coherence: f32,
    /// Mean reconstruction SNR in decibels.
    pub snr_db: f32,
}
1308
1309pub fn evaluate(data: &[Vec<f32>], dict: &Dictionary, n_nonzero: usize) -> SparseLearningReport {
1311 let omp = OrthogonalMatchingPursuit { n_nonzero };
1312 let mut total_rec_err = 0.0_f32;
1313 let mut total_sparsity = 0.0_f32;
1314 let mut total_snr = 0.0_f32;
1315 let n = data.len();
1316
1317 for signal in data.iter() {
1318 let code = omp.encode(signal, &dict.atoms);
1319 let reconstructed = code.reconstruct(&dict.atoms);
1320 let signal_len = signal.len();
1321 total_rec_err += reconstruction_error(signal, &reconstructed);
1322 total_sparsity += sparsity_ratio(&code, signal_len);
1323 total_snr += recovery_quality(signal, &reconstructed);
1324 }
1325
1326 let (rec_err, sparsity, snr) = if n > 0 {
1327 (
1328 total_rec_err / n as f32,
1329 total_sparsity / n as f32,
1330 total_snr / n as f32,
1331 )
1332 } else {
1333 (0.0, 0.0, 0.0)
1334 };
1335
1336 SparseLearningReport {
1337 reconstruction_error: rec_err,
1338 sparsity,
1339 coherence: dict.coherence(),
1340 snr_db: snr,
1341 }
1342}
1343
/// Inner product of two slices (truncates to the shorter length).
#[inline]
pub(crate) fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0, |acc, (&x, &y)| acc + x * y)
}
1350
/// Euclidean (l2) norm of a slice.
#[inline]
pub(crate) fn norm(v: &[f32]) -> f32 {
    v.iter().fold(0.0_f32, |acc, &x| acc + x * x).sqrt()
}
1355
/// Squared l2 norm of a slice (avoids the square root of `norm`).
#[inline]
pub(crate) fn norm_sq(v: &[f32]) -> f32 {
    v.iter().fold(0.0_f32, |acc, &x| acc + x * x)
}