Skip to main content

tenflowers_neural/sparse_learning/
mod.rs

1//! Sparse Coding, Dictionary Learning, and Compressed Sensing.
2//! OMP/ISTA encoders, K-SVD/Online dict learning, ADMM/CoSaMP recovery,
3//! sparse autoencoder, matching pursuit, LASSO/ElasticNet/GroupLASSO regression.
4
5use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
6use scirs2_core::RngExt;
7
8pub mod extensions;
9pub use extensions::*;
10
11pub mod advanced;
12pub use advanced::*;
13
14#[cfg(test)]
15mod tests;
16
17// ── SparsityMeasure ────────────────────────────────────────────────────────────
18
/// Penalty used to promote sparsity when a signal is encoded.
///
/// Each variant carries the numeric parameters of its penalty.
#[derive(Debug, Clone, PartialEq)]
pub enum SparsityMeasure {
    /// L0 pseudo-norm constraint: keep at most this many non-zero coefficients.
    L0(usize),
    /// L1 (LASSO) penalty; the payload is the penalty weight.
    L1(f32),
    /// Elastic-net penalty; payloads are the L1 weight followed by the L2 weight.
    ElasticNet(f32, f32),
}
29
30// ── SparseCode ────────────────────────────────────────────────────────────────
31
/// A sparse representation of a signal: coefficients paired with the indices of
/// the dictionary atoms they multiply.
///
/// `coefficients[i]` is the weight of atom `support[i]`.
#[derive(Debug, Clone)]
pub struct SparseCode {
    /// Coefficients, aligned element-wise with `support`.
    pub coefficients: Vec<f32>,
    /// Indices of the dictionary atoms in use (the support).
    pub support: Vec<usize>,
}

impl SparseCode {
    /// Number of stored coefficients that are exactly non-zero.
    ///
    /// Support entries whose coefficient is `0.0` are not counted.
    pub fn nnz(&self) -> usize {
        self.coefficients.iter().filter(|&&c| c != 0.0).count()
    }

    /// Reconstruct the signal as `Σ_i coefficients[i] · dictionary[support[i]]`.
    ///
    /// The output length is the length of the first dictionary atom. A code with
    /// an empty support therefore reconstructs to the all-zero signal of that
    /// length, so it can be compared element-wise with the original signal.
    /// Support indices outside the dictionary are ignored. An empty vector is
    /// returned only when `dictionary` itself is empty.
    pub fn reconstruct(&self, dictionary: &[Vec<f32>]) -> Vec<f32> {
        let dim = match dictionary.first() {
            Some(first) => first.len(),
            None => return Vec::new(),
        };
        let mut out = vec![0.0_f32; dim];
        for (&idx, &coef) in self.support.iter().zip(&self.coefficients) {
            if let Some(atom) = dictionary.get(idx) {
                for (o, &a) in out.iter_mut().zip(atom) {
                    *o += coef * a;
                }
            }
        }
        out
    }
}
64
65// ── OrthogonalMatchingPursuit ─────────────────────────────────────────────────
66
67/// Greedy sparse encoder using Orthogonal Matching Pursuit.
68pub struct OrthogonalMatchingPursuit {
69    /// Number of non-zero coefficients (sparsity budget).
70    pub n_nonzero: usize,
71}
72
73impl OrthogonalMatchingPursuit {
74    /// Encode `signal` using at most `n_nonzero` dictionary atoms.
75    ///
76    /// Algorithm: iteratively select the most correlated atom with the residual,
77    /// then project the signal onto the current support via least-squares,
78    /// and update the residual.
79    pub fn encode(&self, signal: &[f32], dictionary: &[Vec<f32>]) -> SparseCode {
80        let n_atoms = dictionary.len();
81        let k = self.n_nonzero.min(n_atoms);
82        let mut residual = signal.to_vec();
83        let mut support: Vec<usize> = Vec::with_capacity(k);
84
85        for _ in 0..k {
86            // Select atom with highest correlation to residual
87            let mut best_idx = 0;
88            let mut best_corr = f32::NEG_INFINITY;
89            for (j, atom) in dictionary.iter().enumerate() {
90                if support.contains(&j) {
91                    continue;
92                }
93                let corr: f32 = dot(&residual, atom).abs();
94                if corr > best_corr {
95                    best_corr = corr;
96                    best_idx = j;
97                }
98            }
99            support.push(best_idx);
100
101            // Solve least squares: min ||signal - D_S * x||^2 via Gram-Schmidt / normal eqs
102            let coefficients = omp_least_squares(signal, dictionary, &support);
103
104            // Update residual
105            residual = signal.to_vec();
106            for (&idx, &c) in support.iter().zip(coefficients.iter()) {
107                for (r, &a) in residual.iter_mut().zip(dictionary[idx].iter()) {
108                    *r -= c * a;
109                }
110            }
111
112            // Early exit if residual is negligible
113            if norm_sq(&residual) < 1e-10 {
114                let coefs = omp_least_squares(signal, dictionary, &support);
115                return SparseCode {
116                    coefficients: coefs,
117                    support,
118                };
119            }
120        }
121
122        let coefficients = omp_least_squares(signal, dictionary, &support);
123        SparseCode {
124            coefficients,
125            support,
126        }
127    }
128}
129
130/// Solve the least-squares system signal ≈ D_S * x for selected support atoms.
131pub(crate) fn omp_least_squares(
132    signal: &[f32],
133    dictionary: &[Vec<f32>],
134    support: &[usize],
135) -> Vec<f32> {
136    let k = support.len();
137    let ds: Vec<Vec<f32>> = support.iter().map(|&i| dictionary[i].clone()).collect();
138    let mut gram = vec![vec![0.0_f32; k]; k];
139    for i in 0..k {
140        for j in 0..k {
141            gram[i][j] = dot(&ds[i], &ds[j]);
142        }
143    }
144    let rhs: Vec<f32> = ds.iter().map(|col| dot(col, signal)).collect();
145    cholesky_solve(&gram, &rhs).unwrap_or_else(|_| vec![0.0_f32; k])
146}
147
/// Solve the square linear system `A x = b` by Gauss–Jordan elimination with
/// partial pivoting.
///
/// Despite its name, this routine performs no Cholesky factorisation. `A` therefore
/// need not be symmetric positive-definite: any numerically non-singular square
/// matrix works. Only the leading `n × n` block of `a` is used, where
/// `n = b.len()`. It is intended for the small Gram systems built by the pursuit
/// algorithms in this module.
///
/// # Errors
/// Returns `Err(())` in either of these cases:
/// - `a` has fewer than `n` rows, or one of its first `n` rows has fewer than
///   `n` entries.
/// - A pivot's magnitude is NaN or below `1e-12`, i.e. the system is numerically
///   singular.
pub(crate) fn cholesky_solve(a: &[Vec<f32>], b: &[f32]) -> Result<Vec<f32>, ()> {
    let n = b.len();
    if n == 0 {
        return Ok(Vec::new());
    }
    // A malformed matrix would otherwise panic on an out-of-bounds index.
    if a.len() < n || a[..n].iter().any(|row| row.len() < n) {
        return Err(());
    }

    // Augmented matrix [A | b], restricted to the leading n × n block of A.
    let mut m: Vec<Vec<f32>> = a[..n]
        .iter()
        .zip(b)
        .map(|(row, &bi)| {
            let mut aug = Vec::with_capacity(n + 1);
            aug.extend_from_slice(&row[..n]);
            aug.push(bi);
            aug
        })
        .collect();

    for col in 0..n {
        // Partial pivoting: pick the largest-magnitude entry on or below the
        // diagonal (the first one wins ties).
        let mut pivot_row = col;
        let mut max_val = m[col][col].abs();
        for (offset, row) in m[col + 1..].iter().enumerate() {
            let candidate = row[col].abs();
            if candidate > max_val {
                max_val = candidate;
                pivot_row = col + 1 + offset;
            }
        }
        // NaN pivots would silently poison the whole solution.
        if max_val.is_nan() || max_val < 1e-12 {
            return Err(());
        }
        m.swap(col, pivot_row);

        // Normalise the pivot row and move it out once, instead of cloning it
        // for every row that is eliminated.
        let mut pivot = std::mem::take(&mut m[col]);
        let diag = pivot[col];
        for v in pivot.iter_mut() {
            *v /= diag;
        }
        for (row_idx, row) in m.iter_mut().enumerate() {
            if row_idx == col {
                continue;
            }
            let factor = row[col];
            if factor == 0.0 {
                // Nothing to eliminate in this row.
                continue;
            }
            for (v, &p) in row.iter_mut().zip(&pivot) {
                *v -= factor * p;
            }
        }
        m[col] = pivot;
    }
    // After full reduction, the last (augmented) column holds the solution.
    Ok(m.iter().map(|row| row[n]).collect())
}
195
196// ── LassoEncoder ─────────────────────────────────────────────────────────────
197
198/// ISTA-based LASSO encoder.
199pub struct LassoEncoder {
200    /// L1 regularization weight.
201    pub lambda: f32,
202    /// Maximum number of ISTA iterations.
203    pub max_iter: usize,
204    /// Convergence tolerance.
205    pub tol: f32,
206}
207
208impl LassoEncoder {
209    /// ISTA: iterative soft-thresholding.
210    ///
211    /// Update: x ← soft_threshold(x + D^T(y - Dx), λ/L)
212    /// where L is the Lipschitz constant (largest eigenvalue of D^T D).
213    pub fn encode_ista(&self, signal: &[f32], dictionary: &[Vec<f32>]) -> SparseCode {
214        let n_atoms = dictionary.len();
215        if n_atoms == 0 {
216            return SparseCode {
217                coefficients: Vec::new(),
218                support: Vec::new(),
219            };
220        }
221
222        // Estimate Lipschitz constant via power iteration on D^T D
223        let lipschitz = estimate_lipschitz(dictionary);
224        let step = if lipschitz > 1e-10 {
225            1.0 / lipschitz
226        } else {
227            0.01
228        };
229        let threshold = self.lambda * step;
230
231        let mut x = vec![0.0_f32; n_atoms];
232        for _ in 0..self.max_iter {
233            let x_old = x.clone();
234            // residual = signal - D * x
235            let mut residual = signal.to_vec();
236            for (j, atom) in dictionary.iter().enumerate() {
237                let coef = x[j];
238                for (r, &a) in residual.iter_mut().zip(atom.iter()) {
239                    *r -= coef * a;
240                }
241            }
242            // gradient step: x += D^T * residual * step
243            for (j, atom) in dictionary.iter().enumerate() {
244                x[j] += dot(atom, &residual) * step;
245            }
246            // soft threshold
247            for v in x.iter_mut() {
248                *v = Self::soft_threshold(*v, threshold);
249            }
250            // convergence check
251            let change: f32 = x
252                .iter()
253                .zip(x_old.iter())
254                .map(|(a, b)| (a - b).powi(2))
255                .sum::<f32>()
256                .sqrt();
257            if change < self.tol {
258                break;
259            }
260        }
261
262        let support: Vec<usize> = x
263            .iter()
264            .enumerate()
265            .filter(|(_, &v)| v != 0.0)
266            .map(|(i, _)| i)
267            .collect();
268        let coefficients: Vec<f32> = support.iter().map(|&i| x[i]).collect();
269        SparseCode {
270            coefficients,
271            support,
272        }
273    }
274
275    /// Soft thresholding: sign(x) * max(|x| - threshold, 0).
276    #[inline]
277    pub fn soft_threshold(x: f32, threshold: f32) -> f32 {
278        if x > threshold {
279            x - threshold
280        } else if x < -threshold {
281            x + threshold
282        } else {
283            0.0
284        }
285    }
286}
287
288/// Estimate the Lipschitz constant of D^T D via power iteration.
289pub(crate) fn estimate_lipschitz(dictionary: &[Vec<f32>]) -> f32 {
290    let n = dictionary.len();
291    if n == 0 {
292        return 1.0;
293    }
294    let mut v = vec![1.0_f32 / (n as f32).sqrt(); n];
295    for _ in 0..20 {
296        // w = D^T D v
297        let mut w = vec![0.0_f32; n];
298        for (i, atom_i) in dictionary.iter().enumerate() {
299            for (j, atom_j) in dictionary.iter().enumerate() {
300                w[i] += dot(atom_i, atom_j) * v[j];
301            }
302        }
303        let nrm = norm(&w);
304        if nrm < 1e-12 {
305            return 1.0;
306        }
307        v = w.iter().map(|&x| x / nrm).collect();
308    }
309    // Rayleigh quotient
310    let mut w = vec![0.0_f32; n];
311    for (i, atom_i) in dictionary.iter().enumerate() {
312        for (j, atom_j) in dictionary.iter().enumerate() {
313            w[i] += dot(atom_i, atom_j) * v[j];
314        }
315    }
316    dot(&w, &v).max(1e-10)
317}
318
319// ── DictionaryLearning ────────────────────────────────────────────────────────
320
/// Hyper-parameters shared by the dictionary-learning algorithms in this module.
#[derive(Debug, Clone)]
pub struct DlConfig {
    /// How many atoms the dictionary holds.
    pub n_atoms: usize,
    /// Sparsity budget: maximum non-zero coefficients per encoded signal.
    pub n_nonzero: usize,
    /// Upper bound on training iterations.
    pub max_iter: usize,
    /// Threshold used by the convergence test.
    pub tol: f32,
}
333
334/// A learned dictionary: a set of unit-norm atoms.
335#[derive(Debug, Clone)]
336pub struct Dictionary {
337    /// Dictionary atoms stored as row vectors.
338    pub atoms: Vec<Vec<f32>>,
339    /// Number of atoms in the dictionary.
340    pub n_atoms: usize,
341    /// Dimensionality of each atom.
342    pub atom_dim: usize,
343}
344
345impl Dictionary {
346    /// Create a new dictionary with given atoms.
347    pub fn new(atoms: Vec<Vec<f32>>) -> Self {
348        let n_atoms = atoms.len();
349        let atom_dim = atoms.first().map(|a| a.len()).unwrap_or(0);
350        Self {
351            atoms,
352            n_atoms,
353            atom_dim,
354        }
355    }
356
357    /// Get the i-th atom.
358    pub fn atom(&self, i: usize) -> &[f32] {
359        &self.atoms[i]
360    }
361
362    /// Normalize all atoms to unit L2 norm.
363    pub fn normalize_atoms(&mut self) {
364        for atom in self.atoms.iter_mut() {
365            let n = norm(atom);
366            if n > 1e-10 {
367                for v in atom.iter_mut() {
368                    *v /= n;
369                }
370            }
371        }
372    }
373
374    /// Mutual coherence: max |<a_i, a_j>| for i ≠ j.
375    pub fn coherence(&self) -> f32 {
376        let n = self.atoms.len();
377        let mut max_corr = 0.0_f32;
378        for i in 0..n {
379            for j in (i + 1)..n {
380                let c = dot(&self.atoms[i], &self.atoms[j]).abs();
381                if c > max_corr {
382                    max_corr = c;
383                }
384            }
385        }
386        max_corr
387    }
388}
389
390// ── K-SVD ─────────────────────────────────────────────────────────────────────
391
392/// K-SVD dictionary learning.
393pub struct KSvd {
394    /// Configuration for dictionary learning.
395    pub config: DlConfig,
396}
397
398impl KSvd {
399    /// Fit a dictionary to `data` using the K-SVD algorithm.
400    ///
401    /// Alternates between:
402    /// 1. Sparse coding step (OMP)
403    /// 2. Dictionary update step (rank-1 SVD for each atom)
404    pub fn fit(&self, data: &[Vec<f32>]) -> (Dictionary, Vec<SparseCode>) {
405        if data.is_empty() {
406            let d = Dictionary {
407                atoms: Vec::new(),
408                n_atoms: 0,
409                atom_dim: 0,
410            };
411            return (d, Vec::new());
412        }
413        let signal_dim = data[0].len();
414        let n_atoms = self.config.n_atoms;
415        let n_samples = data.len();
416
417        // Initialize dictionary randomly from data samples
418        let mut rng = StdRng::seed_from_u64(42);
419        let mut atoms: Vec<Vec<f32>> = (0..n_atoms)
420            .map(|i| {
421                let sample = &data[i % n_samples];
422                let mut atom = sample.clone();
423                // Add tiny jitter for distinct atoms
424                for v in atom.iter_mut() {
425                    *v += (rng.random::<f32>() - 0.5) * 1e-4;
426                }
427                atom
428            })
429            .collect();
430        // Normalize
431        for atom in atoms.iter_mut() {
432            let n = norm(atom);
433            if n > 1e-10 {
434                for v in atom.iter_mut() {
435                    *v /= n;
436                }
437            }
438        }
439
440        let omp = OrthogonalMatchingPursuit {
441            n_nonzero: self.config.n_nonzero,
442        };
443        let mut codes: Vec<SparseCode> = data.iter().map(|s| omp.encode(s, &atoms)).collect();
444
445        for _iter in 0..self.config.max_iter {
446            let old_atoms = atoms.clone();
447
448            // Dictionary update: for each atom k, update using rank-1 SVD
449            for k in 0..n_atoms {
450                // Find samples that use atom k
451                let using: Vec<usize> = (0..n_samples)
452                    .filter(|&i| codes[i].support.contains(&k))
453                    .collect();
454                if using.is_empty() {
455                    continue;
456                }
457
458                // E_k = data - sum_{j≠k} c_j * a_j
459                let e_k: Vec<Vec<f32>> = using
460                    .iter()
461                    .map(|&i| {
462                        let mut e = data[i].clone();
463                        for (&sup_idx, &coef) in
464                            codes[i].support.iter().zip(codes[i].coefficients.iter())
465                        {
466                            if sup_idx == k {
467                                continue;
468                            }
469                            for (ev, &av) in e.iter_mut().zip(atoms[sup_idx].iter()) {
470                                *ev -= coef * av;
471                            }
472                        }
473                        e
474                    })
475                    .collect();
476                let (u, sigma, v) = rank1_svd(&e_k, signal_dim, 20);
477                atoms[k] = u;
478                for (sample_pos, &sample_idx) in using.iter().enumerate() {
479                    let new_coef = sigma * v[sample_pos];
480                    // Find position of k in the support
481                    if let Some(pos) = codes[sample_idx].support.iter().position(|&s| s == k) {
482                        codes[sample_idx].coefficients[pos] = new_coef;
483                    }
484                }
485            }
486
487            // Check convergence
488            let change: f32 = atoms
489                .iter()
490                .zip(old_atoms.iter())
491                .map(|(a, b)| {
492                    a.iter()
493                        .zip(b.iter())
494                        .map(|(x, y)| (x - y).powi(2))
495                        .sum::<f32>()
496                })
497                .sum::<f32>()
498                .sqrt();
499            if change < self.config.tol {
500                break;
501            }
502
503            // Re-encode all signals
504            codes = data.iter().map(|s| omp.encode(s, &atoms)).collect();
505        }
506
507        let dict = Dictionary::new(atoms);
508        (dict, codes)
509    }
510}
511
512/// Rank-1 SVD via power iteration. Returns (u, sigma, v).
513pub(crate) fn rank1_svd(
514    matrix: &[Vec<f32>],
515    _signal_dim: usize,
516    max_iter: usize,
517) -> (Vec<f32>, f32, Vec<f32>) {
518    let m = matrix.len();
519    if m == 0 {
520        return (Vec::new(), 0.0, Vec::new());
521    }
522    let n = matrix[0].len();
523    let mut v: Vec<f32> = (0..m).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
524    let mut u = vec![0.0_f32; n];
525    for _ in 0..max_iter {
526        for i in 0..n {
527            u[i] = matrix
528                .iter()
529                .zip(v.iter())
530                .map(|(row, &vj)| row[i] * vj)
531                .sum();
532        }
533        let sigma = norm(&u);
534        if sigma < 1e-12 {
535            return (vec![0.0; n], 0.0, vec![0.0; m]);
536        }
537        for x in u.iter_mut() {
538            *x /= sigma;
539        }
540
541        // v = matrix^T * u = sum_i u[i] * matrix[j][i]
542        for j in 0..m {
543            v[j] = u
544                .iter()
545                .zip(matrix[j].iter())
546                .map(|(&ui, &aij)| ui * aij)
547                .sum();
548        }
549        let vnorm = norm(&v);
550        if vnorm < 1e-12 {
551            break;
552        }
553        for x in v.iter_mut() {
554            *x /= vnorm;
555        }
556    }
557
558    let mut mv = vec![0.0_f32; n];
559    for i in 0..n {
560        mv[i] = matrix
561            .iter()
562            .zip(v.iter())
563            .map(|(row, &vj)| row[i] * vj)
564            .sum();
565    }
566    let sigma = dot(&u, &mv);
567    (u, sigma, v)
568}
569
570// ── OnlineDictionaryLearning ──────────────────────────────────────────────────
571
572/// Online dictionary learning (Mairal et al. 2009).
573pub struct OnlineDictionaryLearning {
574    /// Dictionary learning configuration.
575    pub config: DlConfig,
576    /// A accumulator: n_atoms × n_atoms
577    pub a_matrix: Vec<Vec<f32>>,
578    /// B accumulator: signal_dim × n_atoms
579    pub b_matrix: Vec<Vec<f32>>,
580}
581
582impl OnlineDictionaryLearning {
583    /// Create a new online dictionary learner.
584    pub fn new(config: DlConfig, signal_dim: usize) -> Self {
585        let n_atoms = config.n_atoms;
586        let a_matrix = vec![vec![0.0_f32; n_atoms]; n_atoms];
587        let b_matrix = vec![vec![0.0_f32; n_atoms]; signal_dim];
588        Self {
589            config,
590            a_matrix,
591            b_matrix,
592        }
593    }
594
595    /// Process one sample and return updated dictionary.
596    pub fn update(&mut self, sample: &[f32], current_dict: &Dictionary) -> Dictionary {
597        let omp = OrthogonalMatchingPursuit {
598            n_nonzero: self.config.n_nonzero,
599        };
600        let code = omp.encode(sample, &current_dict.atoms);
601
602        let n_atoms = self.config.n_atoms;
603        let signal_dim = sample.len();
604
605        // Build dense coefficient vector
606        let mut alpha = vec![0.0_f32; n_atoms];
607        for (&idx, &coef) in code.support.iter().zip(code.coefficients.iter()) {
608            if idx < n_atoms {
609                alpha[idx] = coef;
610            }
611        }
612
613        // Update A: A += alpha * alpha^T
614        for i in 0..n_atoms {
615            for j in 0..n_atoms {
616                self.a_matrix[i][j] += alpha[i] * alpha[j];
617            }
618        }
619        // Update B: B += sample * alpha^T  (signal_dim × n_atoms)
620        for i in 0..signal_dim.min(self.b_matrix.len()) {
621            for j in 0..n_atoms {
622                self.b_matrix[i][j] += sample[i] * alpha[j];
623            }
624        }
625
626        // Update dictionary via block-coordinate descent
627        let mut new_atoms = current_dict.atoms.clone();
628        for k in 0..n_atoms {
629            let a_kk = self.a_matrix[k][k];
630            if a_kk.abs() < 1e-10 {
631                continue;
632            }
633
634            // u_k = (b_k - D * a_k + d_k * a_kk) / a_kk
635            let mut u_k = vec![0.0_f32; signal_dim];
636            for i in 0..signal_dim.min(self.b_matrix.len()) {
637                u_k[i] = self.b_matrix[i][k];
638            }
639            for j in 0..n_atoms {
640                if j == k {
641                    continue;
642                }
643                let a_jk = self.a_matrix[j][k];
644                for i in 0..signal_dim.min(new_atoms[j].len()) {
645                    u_k[i] -= new_atoms[j][i] * a_jk;
646                }
647            }
648            for v in u_k.iter_mut() {
649                *v /= a_kk;
650            }
651
652            // Normalize
653            let n = norm(&u_k);
654            if n > 1e-10 {
655                for v in u_k.iter_mut() {
656                    *v /= n;
657                }
658            }
659            new_atoms[k] = u_k;
660        }
661
662        Dictionary::new(new_atoms)
663    }
664
665    /// Fit a dictionary on a stream of samples.
666    pub fn fit_stream(&mut self, data: &[Vec<f32>]) -> Dictionary {
667        if data.is_empty() {
668            return Dictionary {
669                atoms: Vec::new(),
670                n_atoms: 0,
671                atom_dim: 0,
672            };
673        }
674        let signal_dim = data[0].len();
675        let n_atoms = self.config.n_atoms;
676
677        // Initialize dictionary from random samples
678        let mut rng = StdRng::seed_from_u64(123);
679        let mut atoms: Vec<Vec<f32>> = (0..n_atoms)
680            .map(|i| {
681                let mut atom = data[i % data.len()].clone();
682                for v in atom.iter_mut() {
683                    *v += (rng.random::<f32>() - 0.5) * 1e-3;
684                }
685                atom
686            })
687            .collect();
688        for atom in atoms.iter_mut() {
689            let n = norm(atom);
690            if n > 1e-10 {
691                for v in atom.iter_mut() {
692                    *v /= n;
693                }
694            }
695        }
696
697        let mut dict = Dictionary::new(atoms);
698        for sample in data.iter() {
699            dict = self.update(sample, &dict);
700        }
701        dict
702    }
703}
704
705// ── CompressedSensing ─────────────────────────────────────────────────────────
706
/// Kinds of measurement matrix available for compressed sensing.
///
/// The scalings mentioned below describe the matrices that `generate_matrix`
/// builds for each variant.
#[derive(Debug, Clone)]
pub enum MeasurementMatrix {
    /// Gaussian entries scaled by `1/√m`; fields are `(m measurements, n signal length)`.
    Gaussian(usize, usize),
    /// Random `±1/√m` (Bernoulli) entries; fields are `(m measurements, n signal length)`.
    Bernoulli(usize, usize),
    /// Leading `m` rows of an orthonormal DCT-II matrix; the field is `m`.
    ///
    /// No signal length is carried, so `generate_matrix` uses `max(2·m, 4)` columns.
    Dct(usize),
}
717
/// Outcome of taking compressive measurements of a signal.
#[derive(Debug, Clone)]
pub struct CsMeasurement {
    /// The measured values (`y = A · signal`).
    pub y: Vec<f32>,
    /// How many measurements were taken.
    pub n_measurements: usize,
    /// Length of the signal that was measured.
    pub signal_len: usize,
}
728
729/// Measurement matrix stored row by row.
730#[derive(Debug, Clone)]
731pub struct CsMatrix {
732    /// Matrix rows (each row is one measurement vector).
733    pub rows: Vec<Vec<f32>>,
734}
735
736impl CsMatrix {
737    /// Number of rows (measurements).
738    pub fn nrows(&self) -> usize {
739        self.rows.len()
740    }
741    /// Number of columns (signal dimension).
742    pub fn ncols(&self) -> usize {
743        self.rows.first().map(|r| r.len()).unwrap_or(0)
744    }
745
746    /// Compute A^T * v  (ncols-dim result).
747    pub(crate) fn transpose_mul(&self, v: &[f32]) -> Vec<f32> {
748        let ncols = self.ncols();
749        let mut result = vec![0.0_f32; ncols];
750        for (row, &vi) in self.rows.iter().zip(v.iter()) {
751            for (r, &a) in result.iter_mut().zip(row.iter()) {
752                *r += a * vi;
753            }
754        }
755        result
756    }
757
758    /// Compute A * v  (nrows-dim result).
759    pub(crate) fn forward_mul(&self, v: &[f32]) -> Vec<f32> {
760        self.rows.iter().map(|row| dot(row, v)).collect()
761    }
762}
763
764/// Generate a measurement matrix from a specification.
765pub fn generate_matrix(m: &MeasurementMatrix, rng: &mut impl Rng) -> CsMatrix {
766    match m {
767        MeasurementMatrix::Gaussian(rows, cols) => {
768            let scale = 1.0 / (*rows as f32).sqrt();
769            let rows_data: Vec<Vec<f32>> = (0..*rows)
770                .map(|_| (0..*cols).map(|_| box_muller(rng) * scale).collect())
771                .collect();
772            CsMatrix { rows: rows_data }
773        }
774        MeasurementMatrix::Bernoulli(rows, cols) => {
775            let scale = 1.0 / (*rows as f32).sqrt();
776            let rows_data: Vec<Vec<f32>> = (0..*rows)
777                .map(|_| {
778                    (0..*cols)
779                        .map(|_| {
780                            if rng.random::<f32>() > 0.5 {
781                                scale
782                            } else {
783                                -scale
784                            }
785                        })
786                        .collect()
787                })
788                .collect();
789            CsMatrix { rows: rows_data }
790        }
791        MeasurementMatrix::Dct(m_rows) => {
792            // Use first m_rows rows from the DCT matrix (size estimated as 2*m_rows)
793            let n = (2 * m_rows).max(4);
794            let dct_rows: Vec<Vec<f32>> = (0..*m_rows)
795                .map(|k| {
796                    (0..n)
797                        .map(|j| {
798                            let scale = if k == 0 {
799                                (1.0 / n as f32).sqrt()
800                            } else {
801                                (2.0 / n as f32).sqrt()
802                            };
803                            let angle = std::f32::consts::PI * k as f32 * (2 * j + 1) as f32
804                                / (2 * n) as f32;
805                            scale * angle.cos()
806                        })
807                        .collect()
808                })
809                .collect();
810            CsMatrix { rows: dct_rows }
811        }
812    }
813}
814
815/// Box-Muller transform for standard normal samples.
816pub(crate) fn box_muller(rng: &mut impl Rng) -> f32 {
817    let u1 = (rng.random::<f32>()).max(1e-30);
818    let u2 = rng.random::<f32>();
819    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
820}
821
822/// Measure a signal using a compressive sensing matrix.
823pub fn measure(signal: &[f32], matrix: &CsMatrix) -> CsMeasurement {
824    let signal_len = signal.len();
825    let y = matrix.forward_mul(signal);
826    let n_measurements = y.len();
827    CsMeasurement {
828        y,
829        n_measurements,
830        signal_len,
831    }
832}
833
834// ── BasisPursuit (ADMM) ───────────────────────────────────────────────────────
835
836/// Basis Pursuit via ADMM: minimize ||x||_1 subject to Ax = y.
837pub struct BasisPursuit {
838    /// Maximum ADMM iterations.
839    pub max_iter: usize,
840    /// ADMM penalty parameter rho.
841    pub rho: f32,
842    /// Primal/dual residual convergence tolerance.
843    pub tol: f32,
844}
845
846impl BasisPursuit {
847    /// Recover the sparse signal from compressed measurements using ADMM.
848    pub fn recover(&self, measurement: &CsMeasurement, matrix: &CsMatrix) -> Vec<f32> {
849        let n = measurement.signal_len;
850        let rho = self.rho;
851
852        // ADMM variables
853        let mut x = vec![0.0_f32; n];
854        let mut z = vec![0.0_f32; n];
855        let mut u = vec![0.0_f32; n]; // scaled dual
856
857        // Precompute A^T * y
858        let aty: Vec<f32> = matrix.transpose_mul(&measurement.y);
859
860        for _ in 0..self.max_iter {
861            // x-update: (A^T A + rho * I) x = A^T y + rho * (z - u)
862            // Use conjugate gradient (small/medium problems)
863            let rhs: Vec<f32> = (0..n).map(|i| aty[i] + rho * (z[i] - u[i])).collect();
864            x = admm_cg(matrix, rho, &rhs, &x, 50);
865
866            // z-update: soft threshold
867            let z_old = z.clone();
868            for i in 0..n {
869                z[i] = LassoEncoder::soft_threshold(x[i] + u[i], 1.0 / rho);
870            }
871
872            // u-update
873            for i in 0..n {
874                u[i] += x[i] - z[i];
875            }
876
877            // Primal and dual residual
878            let primal: f32 = (0..n).map(|i| (x[i] - z[i]).powi(2)).sum::<f32>().sqrt();
879            let dual: f32 = (0..n)
880                .map(|i| (rho * (z[i] - z_old[i])).powi(2))
881                .sum::<f32>()
882                .sqrt();
883            if primal < self.tol && dual < self.tol {
884                break;
885            }
886        }
887        z
888    }
889}
890
891/// Conjugate gradient for (A^T A + rho * I) x = b.
892fn admm_cg(matrix: &CsMatrix, rho: f32, b: &[f32], x0: &[f32], max_iter: usize) -> Vec<f32> {
893    let n = b.len();
894    let mut x = x0.to_vec();
895    // r = b - (A^T A + rho I) x0
896    let ax = matrix.forward_mul(&x);
897    let atax = matrix.transpose_mul(&ax);
898    let mut r: Vec<f32> = (0..n).map(|i| b[i] - atax[i] - rho * x[i]).collect();
899    let mut p = r.clone();
900    let mut rsold: f32 = r.iter().map(|&v| v * v).sum();
901
902    for _ in 0..max_iter {
903        if rsold < 1e-10 {
904            break;
905        }
906        let ap_full = matrix.forward_mul(&p);
907        let atp = matrix.transpose_mul(&ap_full);
908        let ap: Vec<f32> = (0..n).map(|i| atp[i] + rho * p[i]).collect();
909        let denom: f32 = p.iter().zip(ap.iter()).map(|(&pi, &api)| pi * api).sum();
910        if denom.abs() < 1e-14 {
911            break;
912        }
913        let alpha = rsold / denom;
914        for i in 0..n {
915            x[i] += alpha * p[i];
916            r[i] -= alpha * ap[i];
917        }
918        let rsnew: f32 = r.iter().map(|&v| v * v).sum();
919        let beta = rsnew / rsold.max(1e-14);
920        for i in 0..n {
921            p[i] = r[i] + beta * p[i];
922        }
923        rsold = rsnew;
924    }
925    x
926}
927
928// ── CoSaMP ────────────────────────────────────────────────────────────────────
929
930/// Compressive Sampling Matching Pursuit.
931pub struct CoSaMP {
932    /// Sparsity level (number of non-zeros to recover).
933    pub n_nonzero: usize,
934    /// Maximum iterations.
935    pub max_iter: usize,
936}
937
938impl CoSaMP {
939    /// Recover signal via CoSaMP: support extension + LS + pruning.
940    pub fn recover(&self, measurement: &CsMeasurement, matrix: &CsMatrix) -> Vec<f32> {
941        let n = measurement.signal_len;
942        let s = self.n_nonzero;
943
944        let mut x = vec![0.0_f32; n];
945
946        for _ in 0..self.max_iter {
947            // Proxy: A^T * (y - Ax)
948            let ax = matrix.forward_mul(&x);
949            let residual: Vec<f32> = measurement
950                .y
951                .iter()
952                .zip(ax.iter())
953                .map(|(&y, &a)| y - a)
954                .collect();
955            let proxy = matrix.transpose_mul(&residual);
956
957            // Identify 2s largest components of proxy
958            let mut indices: Vec<usize> = (0..n).collect();
959            indices.sort_by(|&a, &b| {
960                proxy[b]
961                    .abs()
962                    .partial_cmp(&proxy[a].abs())
963                    .unwrap_or(std::cmp::Ordering::Equal)
964            });
965            let mut support: Vec<usize> = indices[..((2 * s).min(n))].to_vec();
966
967            // Merge with current support
968            let current_nonzero: Vec<usize> = x
969                .iter()
970                .enumerate()
971                .filter(|(_, &v)| v != 0.0)
972                .map(|(i, _)| i)
973                .collect();
974            for idx in current_nonzero {
975                if !support.contains(&idx) {
976                    support.push(idx);
977                }
978            }
979
980            // Least squares on merged support
981            let coefs = cosamp_ls(&measurement.y, matrix, &support);
982
983            // Prune to s largest
984            let mut coef_pairs: Vec<(usize, f32)> = support
985                .iter()
986                .zip(coefs.iter())
987                .map(|(&i, &c)| (i, c))
988                .collect();
989            coef_pairs.sort_by(|a, b| {
990                b.1.abs()
991                    .partial_cmp(&a.1.abs())
992                    .unwrap_or(std::cmp::Ordering::Equal)
993            });
994            coef_pairs.truncate(s);
995
996            x = vec![0.0_f32; n];
997            for (i, c) in coef_pairs {
998                x[i] = c;
999            }
1000
1001            // Convergence check
1002            let ax_new = matrix.forward_mul(&x);
1003            let res_norm: f32 = measurement
1004                .y
1005                .iter()
1006                .zip(ax_new.iter())
1007                .map(|(&y, &a)| (y - a).powi(2))
1008                .sum::<f32>()
1009                .sqrt();
1010            if res_norm < 1e-6 {
1011                break;
1012            }
1013        }
1014        x
1015    }
1016}
1017
1018/// Least squares for CoSaMP on a subset of columns.
1019fn cosamp_ls(y: &[f32], matrix: &CsMatrix, support: &[usize]) -> Vec<f32> {
1020    let k = support.len();
1021    let m = y.len();
1022    // Build A_S (m × k)
1023    let as_cols: Vec<Vec<f32>> = support
1024        .iter()
1025        .map(|&i| {
1026            matrix
1027                .rows
1028                .iter()
1029                .map(|row| *row.get(i).unwrap_or(&0.0))
1030                .collect()
1031        })
1032        .collect();
1033    // Solve normal equations: A_S^T A_S x = A_S^T y
1034    let mut gram = vec![vec![0.0_f32; k]; k];
1035    for i in 0..k {
1036        for j in 0..k {
1037            gram[i][j] = (0..m).map(|r| as_cols[i][r] * as_cols[j][r]).sum();
1038        }
1039    }
1040    let rhs: Vec<f32> = (0..k)
1041        .map(|i| (0..m).map(|r| as_cols[i][r] * y[r]).sum())
1042        .collect();
1043    cholesky_solve(&gram, &rhs).unwrap_or_else(|_| vec![0.0_f32; k])
1044}
1045
1046/// Estimate the restricted isometry property constant via random sparse vectors.
1047pub fn rip_constant_estimate(
1048    matrix: &CsMatrix,
1049    s: usize,
1050    n_trials: usize,
1051    rng: &mut impl Rng,
1052) -> f32 {
1053    let n = matrix.ncols();
1054    if n == 0 {
1055        return 0.0;
1056    }
1057    let mut max_delta = 0.0_f32;
1058
1059    for _ in 0..n_trials {
1060        // Generate random s-sparse unit vector
1061        let mut support: Vec<usize> = (0..n).collect();
1062        // Shuffle first s elements
1063        for i in 0..s.min(n) {
1064            let j = i + (rng.random_range(0..(n - i)));
1065            support.swap(i, j);
1066        }
1067        let support = &support[..s.min(n)];
1068
1069        let mut x = vec![0.0_f32; n];
1070        let mut total_sq = 0.0_f32;
1071        for &i in support {
1072            let v = rng.random::<f32>() * 2.0 - 1.0;
1073            x[i] = v;
1074            total_sq += v * v;
1075        }
1076        if total_sq < 1e-10 {
1077            continue;
1078        }
1079        let x_norm_sq = total_sq;
1080        for v in x.iter_mut() {
1081            *v /= x_norm_sq.sqrt();
1082        }
1083
1084        let ax = matrix.forward_mul(&x);
1085        let ax_norm_sq: f32 = ax.iter().map(|&v| v * v).sum();
1086        let delta = (ax_norm_sq - 1.0).abs();
1087        if delta > max_delta {
1088            max_delta = delta;
1089        }
1090    }
1091    max_delta
1092}
1093
1094// ── SparseAutoencoder ─────────────────────────────────────────────────────────
1095
/// Configuration for the sparse autoencoder.
///
/// `input_dim` and `hidden_dim` fix the layer shapes built by `SparseAutoencoder::new`.
/// `sparsity_target` and `sparsity_weight` are consumed by `SparseAutoencoder::total_loss`.
/// `SparseAutoencoder::encode` does not read this config: its k-winner-takes-all size depends
/// only on the hidden width.
#[derive(Debug, Clone)]
pub struct SaeConfig {
    /// Input dimensionality (length of the signals passed to `encode`).
    pub input_dim: usize,
    /// Hidden layer dimensionality (number of code units).
    pub hidden_dim: usize,
    /// Target mean activation `rho` in the Bernoulli KL sparsity penalty used by `total_loss`.
    pub sparsity_target: f32,
    /// Weight multiplying the sparsity penalty in `total_loss`.
    pub sparsity_weight: f32,
}
1108
/// Sparse autoencoder with k-winner-takes-all activation.
///
/// The encoder is a linear layer with ReLU followed by k-winner-takes-all; the decoder is a
/// single linear layer.
#[derive(Debug, Clone)]
pub struct SparseAutoencoder {
    /// Encoder weight matrix: `hidden_dim` rows, each of length `input_dim`.
    pub encoder_w: Vec<Vec<f32>>,
    /// Encoder bias, one entry per hidden unit.
    pub encoder_b: Vec<f32>,
    /// Decoder weight matrix: `input_dim` rows, each of length `hidden_dim`.
    pub decoder_w: Vec<Vec<f32>>,
    /// Decoder bias, one entry per input dimension.
    pub decoder_b: Vec<f32>,
}
1121
1122impl SparseAutoencoder {
1123    /// Create a new sparse autoencoder with Xavier initialization.
1124    pub fn new(config: &SaeConfig, rng: &mut impl Rng) -> Self {
1125        let scale_enc = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
1126        let scale_dec = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
1127
1128        let encoder_w: Vec<Vec<f32>> = (0..config.hidden_dim)
1129            .map(|_| {
1130                (0..config.input_dim)
1131                    .map(|_| box_muller(rng) * scale_enc)
1132                    .collect()
1133            })
1134            .collect();
1135        let encoder_b = vec![0.0_f32; config.hidden_dim];
1136
1137        let decoder_w: Vec<Vec<f32>> = (0..config.input_dim)
1138            .map(|_| {
1139                (0..config.hidden_dim)
1140                    .map(|_| box_muller(rng) * scale_dec)
1141                    .collect()
1142            })
1143            .collect();
1144        let decoder_b = vec![0.0_f32; config.input_dim];
1145
1146        Self {
1147            encoder_w,
1148            encoder_b,
1149            decoder_w,
1150            decoder_b,
1151        }
1152    }
1153
1154    /// Encode with ReLU followed by k-winner-takes-all.
1155    pub fn encode(&self, x: &[f32]) -> Vec<f32> {
1156        let hidden_dim = self.encoder_w.len();
1157        // Linear + ReLU
1158        let mut h: Vec<f32> = (0..hidden_dim)
1159            .map(|i| {
1160                let pre_act = dot(&self.encoder_w[i], x) + self.encoder_b[i];
1161                pre_act.max(0.0) // ReLU
1162            })
1163            .collect();
1164
1165        // k-winner-takes-all: keep top-k activations, zero the rest
1166        let k = ((hidden_dim as f32 * 0.1).ceil() as usize)
1167            .max(1)
1168            .min(hidden_dim);
1169        let mut indexed: Vec<(usize, f32)> = h.iter().cloned().enumerate().collect();
1170        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1171        let threshold_idx = k;
1172        let zero_indices: Vec<usize> = indexed[threshold_idx..].iter().map(|(i, _)| *i).collect();
1173        for i in zero_indices {
1174            h[i] = 0.0;
1175        }
1176        h
1177    }
1178
1179    /// Decode the latent representation back to the input space.
1180    pub fn decode(&self, z: &[f32]) -> Vec<f32> {
1181        let input_dim = self.decoder_w.len();
1182        (0..input_dim)
1183            .map(|i| dot(&self.decoder_w[i], z) + self.decoder_b[i])
1184            .collect()
1185    }
1186
1187    /// Mean squared reconstruction error.
1188    pub fn reconstruction_loss(&self, x: &[f32]) -> f32 {
1189        let z = self.encode(x);
1190        let x_hat = self.decode(&z);
1191        let n = x.len();
1192        if n == 0 {
1193            return 0.0;
1194        }
1195        x.iter()
1196            .zip(x_hat.iter())
1197            .map(|(&a, &b)| (a - b).powi(2))
1198            .sum::<f32>()
1199            / n as f32
1200    }
1201
1202    /// KL-divergence sparsity penalty on hidden activations.
1203    pub fn sparsity_loss(&self, z: &[f32]) -> f32 {
1204        if z.is_empty() {
1205            return 0.0;
1206        }
1207        let rho_hat = z.iter().map(|&v| v.clamp(0.0, 1.0)).sum::<f32>() / z.len() as f32;
1208        let rho = 0.05_f32; // default target
1209        Self::kl_divergence_bernoulli(rho, rho_hat)
1210    }
1211
1212    /// Total loss = reconstruction + sparsity_weight * sparsity.
1213    pub fn total_loss(&self, x: &[f32], config: &SaeConfig) -> f32 {
1214        let z = self.encode(x);
1215        let rec = self.reconstruction_loss(x);
1216        let spar = self.sparsity_loss_with_target(&z, config.sparsity_target);
1217        rec + config.sparsity_weight * spar
1218    }
1219
1220    fn sparsity_loss_with_target(&self, z: &[f32], target: f32) -> f32 {
1221        if z.is_empty() {
1222            return 0.0;
1223        }
1224        let rho_hat = z.iter().map(|&v| v.clamp(0.0, 1.0)).sum::<f32>() / z.len() as f32;
1225        Self::kl_divergence_bernoulli(target, rho_hat)
1226    }
1227
1228    /// KL divergence between Bernoulli(rho) and Bernoulli(rho_hat).
1229    pub fn kl_divergence_bernoulli(rho: f32, rho_hat: f32) -> f32 {
1230        let eps = 1e-8_f32;
1231        let rho = rho.clamp(eps, 1.0 - eps);
1232        let rho_hat = rho_hat.clamp(eps, 1.0 - eps);
1233        rho * (rho / rho_hat).ln() + (1.0 - rho) * ((1.0 - rho) / (1.0 - rho_hat)).ln()
1234    }
1235}
1236
1237// ── SparseLearningMetrics ─────────────────────────────────────────────────────
1238
/// Normalized L2 reconstruction error, `‖original − reconstructed‖₂ / ‖original‖₂`.
///
/// Entries missing from `reconstructed` (when it is shorter than `original`) are treated as 0.
/// An empty reconstruction therefore scores 1.0 instead of a spurious perfect 0.0. Entries
/// beyond `original.len()` are ignored.
///
/// When `‖original‖²` is below `1e-10`, the absolute error `‖original − reconstructed‖₂` is
/// returned instead. An empty `original` yields 0.0.
pub fn reconstruction_error(original: &[f32], reconstructed: &[f32]) -> f32 {
    if original.is_empty() {
        return 0.0;
    }
    // Zero-pad the reconstruction so a short one is penalized rather than truncated away.
    let padded = reconstructed
        .iter()
        .copied()
        .chain(std::iter::repeat(0.0_f32));
    let err: f32 = original
        .iter()
        .zip(padded)
        .map(|(&a, b)| (a - b).powi(2))
        .sum::<f32>();
    let norm_orig: f32 = original.iter().map(|v| v * v).sum::<f32>();
    if norm_orig < 1e-10 {
        return err.sqrt();
    }
    (err / norm_orig).sqrt()
}
1255
1256/// Sparsity ratio: nnz / signal_len.
1257pub fn sparsity_ratio(code: &SparseCode, signal_len: usize) -> f32 {
1258    if signal_len == 0 {
1259        return 0.0;
1260    }
1261    code.nnz() as f32 / signal_len as f32
1262}
1263
/// Welch lower bound on the mutual coherence of `n_atoms` unit-norm atoms in `signal_dim`
/// dimensions: `sqrt((n − d) / (d · (n − 1)))`.
///
/// Returns 0.0 unless `n_atoms > 1`, `signal_dim > 0` and the dictionary is strictly
/// overcomplete (`n_atoms > signal_dim`).
pub fn coherence_bound(n_atoms: usize, signal_dim: usize) -> f32 {
    let (n, d) = (n_atoms as f32, signal_dim as f32);
    let overcomplete = n > 1.0 && d > 0.0 && n > d;
    if overcomplete {
        ((n - d) / (d * (n - 1.0))).sqrt()
    } else {
        0.0
    }
}
1273
/// Recovery quality as PSNR in dB: `10 · log10(peak² / MSE)`.
///
/// `peak` is the largest `|original[i]|`, or 1.0 if that is below `1e-10`. The MSE is averaged
/// over `original.len()`.
///
/// Entries missing from `recovered` are treated as 0, so a short or empty recovery is penalized
/// instead of silently truncated. Entries beyond `original.len()` are ignored.
///
/// Returns 0.0 for an empty `original`. Returns the sentinel 100.0 dB when `MSE < 1e-14`; a tiny
/// but non-zero MSE can still produce values above 100 dB.
pub fn recovery_quality(original: &[f32], recovered: &[f32]) -> f32 {
    if original.is_empty() {
        return 0.0;
    }
    // Zero-pad the recovery so a short one is scored against the full original.
    let padded = recovered.iter().copied().chain(std::iter::repeat(0.0_f32));
    let mse: f32 = original
        .iter()
        .zip(padded)
        .map(|(&a, b)| (a - b).powi(2))
        .sum::<f32>()
        / original.len() as f32;
    if mse < 1e-14 {
        return 100.0;
    }
    let peak = original
        .iter()
        .cloned()
        .fold(0.0_f32, |acc, v| acc.max(v.abs()));
    let peak = if peak < 1e-10 { 1.0 } else { peak };
    10.0 * (peak * peak / mse).log10()
}
1295
/// Summary report for sparse learning evaluation, as produced by [`evaluate`].
///
/// Every field except `coherence` is a mean over the evaluated signals (0.0 for an empty
/// dataset).
#[derive(Debug, Clone)]
pub struct SparseLearningReport {
    /// Average normalized reconstruction error across all signals
    /// (per-signal value from [`reconstruction_error()`]).
    pub reconstruction_error: f32,
    /// Average sparsity ratio `nnz / signal_len` (per-signal value from [`sparsity_ratio`]).
    pub sparsity: f32,
    /// Dictionary coherence as returned by `Dictionary::coherence`.
    /// NOTE(review): presumably the max off-diagonal |Gram| entry; confirm against its definition.
    pub coherence: f32,
    /// Average peak signal-to-noise ratio in dB (per-signal value from [`recovery_quality`]).
    /// Despite the field name, this is PSNR rather than plain SNR.
    pub snr_db: f32,
}
1308
1309/// Evaluate a dictionary and sparse coding on a dataset.
1310pub fn evaluate(data: &[Vec<f32>], dict: &Dictionary, n_nonzero: usize) -> SparseLearningReport {
1311    let omp = OrthogonalMatchingPursuit { n_nonzero };
1312    let mut total_rec_err = 0.0_f32;
1313    let mut total_sparsity = 0.0_f32;
1314    let mut total_snr = 0.0_f32;
1315    let n = data.len();
1316
1317    for signal in data.iter() {
1318        let code = omp.encode(signal, &dict.atoms);
1319        let reconstructed = code.reconstruct(&dict.atoms);
1320        let signal_len = signal.len();
1321        total_rec_err += reconstruction_error(signal, &reconstructed);
1322        total_sparsity += sparsity_ratio(&code, signal_len);
1323        total_snr += recovery_quality(signal, &reconstructed);
1324    }
1325
1326    let (rec_err, sparsity, snr) = if n > 0 {
1327        (
1328            total_rec_err / n as f32,
1329            total_sparsity / n as f32,
1330            total_snr / n as f32,
1331        )
1332    } else {
1333        (0.0, 0.0, 0.0)
1334    };
1335
1336    SparseLearningReport {
1337        reconstruction_error: rec_err,
1338        sparsity,
1339        coherence: dict.coherence(),
1340        snr_db: snr,
1341    }
1342}
1343
1344// ── Helpers ───────────────────────────────────────────────────────────────────
1345
/// Dot product of `a` and `b`, pairing entries up to the shorter slice's length.
#[inline]
pub(crate) fn dot(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b).map(|(x, y)| x * y).sum()
}
1350
/// Euclidean (L2) norm of `v`.
#[inline]
pub(crate) fn norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|x| x * x).sum();
    sum_of_squares.sqrt()
}
1355
/// Squared Euclidean norm of `v`: the sum of squares, with no square root taken.
#[inline]
pub(crate) fn norm_sq(v: &[f32]) -> f32 {
    v.iter().copied().map(|x| x * x).sum()
}