Skip to main content

tenflowers_neural/sparse_learning/
advanced.rs

1//! Advanced sparse learning algorithms: Sparse Transformers, Structured Sparsity,
2//! Sparse Coding Layers, and Compressed Sensing utilities.
3
4use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
5use scirs2_core::RngExt;
6use super::{box_muller, dot, norm, LassoEncoder};
7
8// ── BigBirdAttention ──────────────────────────────────────────────────────────
9
10/// BigBird attention pattern (Zaheer et al. 2020): block-sparse + random + global.
11///
12/// Combines three types of attention:
13/// - Block-sparse: local sliding window of blocks
14/// - Random: random token connections for long-range coverage
15/// - Global: special tokens that attend to/from all positions
pub struct BigBirdAttention {
    /// Sequence length (number of tokens).
    pub seq_len: usize,
    /// Number of attention heads.
    // NOTE(review): not read by the single-head helpers in this impl —
    // presumably consumed by a multi-head driver elsewhere; confirm.
    pub n_heads: usize,
    /// Head dimension.
    pub head_dim: usize,
    /// Block size for block-sparse attention (treated as at least 1 when used).
    pub block_size: usize,
    /// Number of random attention connections per token.
    pub n_random: usize,
    /// Number of global tokens (e.g., CLS, SEP), taken from the start of the sequence.
    pub n_global: usize,
}
30
31impl BigBirdAttention {
32    /// Generate the BigBird attention mask for a sequence.
33    ///
34    /// Returns a boolean matrix of shape (seq_len × seq_len) where `true`
35    /// means the query at row i can attend to key at column j.
36    pub fn compute_attention_mask(&self, rng: &mut impl Rng) -> Vec<Vec<bool>> {
37        let n = self.seq_len;
38        let mut mask = vec![vec![false; n]; n];
39
40        // 1) Block-sparse: each token attends to its local block neighbours
41        let bs = self.block_size.max(1);
42        for i in 0..n {
43            let block_i = i / bs;
44            // Attend to tokens in the adjacent blocks (block_i-1, block_i, block_i+1)
45            let start_block = block_i.saturating_sub(1);
46            let end_block = (block_i + 2).min((n + bs - 1) / bs);
47            for b in start_block..end_block {
48                let start_tok = b * bs;
49                let end_tok = ((b + 1) * bs).min(n);
50                for j in start_tok..end_tok {
51                    mask[i][j] = true;
52                }
53            }
54        }
55
56        // 2) Random: each token additionally attends to `n_random` random tokens
57        for i in 0..n {
58            for _ in 0..self.n_random {
59                let j = rng.random_range(0..n);
60                mask[i][j] = true;
61            }
62        }
63
64        // 3) Global: first n_global tokens attend to/from all tokens
65        for g in 0..self.n_global.min(n) {
66            for j in 0..n {
67                mask[g][j] = true;
68                mask[j][g] = true;
69            }
70        }
71
72        mask
73    }
74
75    /// Compute sparse attention output for a single head.
76    ///
77    /// `q`, `k`, `v` are (seq_len × head_dim) matrices stored row-major.
78    /// Only positions allowed by `mask` contribute to the attention sum.
79    pub fn sparse_attention(
80        &self,
81        q: &[Vec<f32>],
82        k: &[Vec<f32>],
83        v: &[Vec<f32>],
84        mask: &[Vec<bool>],
85    ) -> Vec<Vec<f32>> {
86        let n = self.seq_len.min(q.len());
87        let d = self.head_dim as f32;
88        let scale = 1.0 / d.sqrt();
89        let mut output = vec![vec![0.0_f32; self.head_dim]; n];
90
91        for i in 0..n {
92            // Gather allowed keys for query i
93            let allowed: Vec<usize> = (0..n).filter(|&j| mask[i].get(j).copied().unwrap_or(false)).collect();
94            if allowed.is_empty() {
95                continue;
96            }
97
98            // Compute scaled dot-product scores
99            let scores: Vec<f32> = allowed
100                .iter()
101                .map(|&j| dot(&q[i], &k[j]) * scale)
102                .collect();
103
104            // Softmax over allowed positions
105            let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
106            let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
107            let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-9);
108
109            // Weighted sum of values
110            for (pos, &j) in allowed.iter().enumerate() {
111                let weight = exp_scores[pos] / sum_exp;
112                for (out, &vval) in output[i].iter_mut().zip(v[j].iter()) {
113                    *out += weight * vval;
114                }
115            }
116        }
117        output
118    }
119
120    /// Count the number of active attention connections in a mask.
121    pub fn mask_sparsity(mask: &[Vec<bool>]) -> f32 {
122        let total = (mask.len() * mask.first().map(|r| r.len()).unwrap_or(0)) as f32;
123        if total < 1.0 {
124            return 0.0;
125        }
126        let active: usize = mask.iter().flat_map(|row| row.iter()).filter(|&&b| b).count();
127        active as f32 / total
128    }
129}
130
131// ── SparseSlidingWindowAttention ──────────────────────────────────────────────
132
133/// Sparse sliding-window attention pattern (Beltagy et al. Longformer 2020): local window + global tokens.
134///
135/// Local window attention ensures O(n·w) complexity instead of O(n²).
136/// Named `SparseSlidingWindowAttention` to avoid conflict with `LongformerAttention` in `moe_scaling`.
pub struct SparseSlidingWindowAttention {
    /// Sequence length.
    pub seq_len: usize,
    /// Number of attention heads.
    // NOTE(review): not consulted by this impl's single-head methods —
    // confirm a multi-head caller uses it.
    pub n_heads: usize,
    /// Head dimension.
    pub head_dim: usize,
    /// One-sided window radius (attends to `window_size` tokens left and right).
    pub window_size: usize,
    /// Indices of global tokens (attend to/from all positions).
    /// Out-of-range indices are silently ignored by the mask/attention code.
    pub global_tokens: Vec<usize>,
}
149
150impl SparseSlidingWindowAttention {
151    /// Compute the sliding-window + global attention mask.
152    pub fn compute_attention_mask(&self) -> Vec<Vec<bool>> {
153        let n = self.seq_len;
154        let mut mask = vec![vec![false; n]; n];
155
156        // Sliding window
157        for i in 0..n {
158            let start = i.saturating_sub(self.window_size);
159            let end = (i + self.window_size + 1).min(n);
160            for j in start..end {
161                mask[i][j] = true;
162            }
163        }
164
165        // Global tokens
166        for &g in &self.global_tokens {
167            if g < n {
168                for j in 0..n {
169                    mask[g][j] = true;
170                    mask[j][g] = true;
171                }
172            }
173        }
174
175        mask
176    }
177
178    /// Compute the sparse attention output for a single head using the local window.
179    ///
180    /// `q`, `k`, `v` are (seq_len × head_dim) row vectors.
181    pub fn local_attention(
182        &self,
183        q: &[Vec<f32>],
184        k: &[Vec<f32>],
185        v: &[Vec<f32>],
186    ) -> Vec<Vec<f32>> {
187        let n = self.seq_len.min(q.len());
188        let scale = 1.0 / (self.head_dim as f32).sqrt();
189        let global_set: std::collections::HashSet<usize> = self.global_tokens.iter().cloned().collect();
190        let mut output = vec![vec![0.0_f32; self.head_dim]; n];
191
192        for i in 0..n {
193            let mut attend_to: Vec<usize> = Vec::new();
194
195            // Local window
196            let start = i.saturating_sub(self.window_size);
197            let end = (i + self.window_size + 1).min(n);
198            for j in start..end {
199                attend_to.push(j);
200            }
201
202            // Global tokens (always attend)
203            for &g in &self.global_tokens {
204                if g < n && !attend_to.contains(&g) {
205                    attend_to.push(g);
206                }
207            }
208
209            // If current token is global, attend to all
210            if global_set.contains(&i) {
211                attend_to = (0..n).collect();
212            }
213
214            let scores: Vec<f32> = attend_to.iter().map(|&j| dot(&q[i], &k[j]) * scale).collect();
215            let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
216            let exps: Vec<f32> = scores.iter().map(|&s| (s - max_s).exp()).collect();
217            let sum_e = exps.iter().sum::<f32>().max(1e-9);
218
219            for (pos, &j) in attend_to.iter().enumerate() {
220                let w = exps[pos] / sum_e;
221                for (o, &vv) in output[i].iter_mut().zip(v[j].iter()) {
222                    *o += w * vv;
223                }
224            }
225        }
226        output
227    }
228}
229
230// ── SparseAttentionRouter ─────────────────────────────────────────────────────
231
232/// Learned sparse attention router: selects the top-k keys each query attends to.
233///
234/// The router learns a per-query scoring function (small MLP) to predict which
235/// keys are most relevant, then applies top-k selection.
pub struct SparseAttentionRouter {
    /// Sequence length (upper bound on how many queries are routed).
    pub seq_len: usize,
    /// Token/key dimension.
    pub dim: usize,
    /// Number of keys each query attends to (clamped to the key count).
    pub top_k: usize,
    /// Router MLP weight matrix (dim × dim); queries are projected through
    /// it before being scored against keys.
    pub router_w: Vec<Vec<f32>>,
}
246
247impl SparseAttentionRouter {
248    /// Initialize the router with random weights.
249    pub fn new(seq_len: usize, dim: usize, top_k: usize, rng: &mut impl Rng) -> Self {
250        let scale = (2.0 / (dim + dim) as f32).sqrt();
251        let router_w: Vec<Vec<f32>> = (0..dim)
252            .map(|_| (0..dim).map(|_| box_muller(rng) * scale).collect())
253            .collect();
254        Self { seq_len, dim, top_k, router_w }
255    }
256
257    /// Compute top-k routing: for each query, select the `top_k` most similar keys.
258    ///
259    /// Returns (selected_indices, attention_weights) for each query.
260    pub fn route(&self, queries: &[Vec<f32>], keys: &[Vec<f32>]) -> Vec<(Vec<usize>, Vec<f32>)> {
261        let n_q = queries.len().min(self.seq_len);
262        let n_k = keys.len();
263        let k = self.top_k.min(n_k);
264
265        queries[..n_q].iter().map(|q| {
266            // Project query through router weight
267            let q_proj: Vec<f32> = self.router_w.iter().map(|row| dot(row, q)).collect();
268
269            // Score each key
270            let mut scores: Vec<(usize, f32)> = keys.iter().enumerate().map(|(j, key)| {
271                let score = dot(&q_proj, key) / (self.dim as f32).sqrt();
272                (j, score)
273            }).collect();
274
275            // Top-k selection
276            scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
277            scores.truncate(k);
278
279            // Softmax over top-k
280            let max_s = scores.iter().map(|(_, s)| *s).fold(f32::NEG_INFINITY, f32::max);
281            let exps: Vec<f32> = scores.iter().map(|(_, s)| (s - max_s).exp()).collect();
282            let sum_e = exps.iter().sum::<f32>().max(1e-9);
283            let weights: Vec<f32> = exps.iter().map(|e| e / sum_e).collect();
284            let indices: Vec<usize> = scores.iter().map(|(i, _)| *i).collect();
285
286            (indices, weights)
287        }).collect()
288    }
289
290    /// Compute sparse attention output using the router.
291    pub fn attend(&self, queries: &[Vec<f32>], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<Vec<f32>> {
292        let routing = self.route(queries, keys);
293        routing.into_iter().map(|(indices, weights)| {
294            let mut out = vec![0.0_f32; self.dim];
295            for (idx, w) in indices.iter().zip(weights.iter()) {
296                if *idx < values.len() {
297                    for (o, &v) in out.iter_mut().zip(values[*idx].iter()) {
298                        *o += w * v;
299                    }
300                }
301            }
302            out
303        }).collect()
304    }
305}
306
307// ── SparsePositionEncoding ────────────────────────────────────────────────────
308
309/// Sparse relative position bias for attention (Raffel et al. T5-style, sparse variant).
310///
311/// Rather than dense relative position embeddings, only a subset of relative
312/// distances receive learned biases (sparsifying the position information).
pub struct SparsePositionEncoding {
    /// Maximum sequence length (also caps the size used by `position_bias_indices`).
    pub max_seq_len: usize,
    /// Embedding dimension.
    pub dim: usize,
    /// Number of distinct relative distance buckets: the lower half covers
    /// exact small distances, the upper half covers large distances
    /// logarithmically.
    pub n_buckets: usize,
    /// Bucket embedding table (n_buckets × dim).
    pub bucket_embeddings: Vec<Vec<f32>>,
}
323
324impl SparsePositionEncoding {
325    /// Initialize with random bucket embeddings.
326    pub fn new(max_seq_len: usize, dim: usize, n_buckets: usize, rng: &mut impl Rng) -> Self {
327        let scale = (1.0 / dim as f32).sqrt();
328        let bucket_embeddings = (0..n_buckets)
329            .map(|_| (0..dim).map(|_| box_muller(rng) * scale).collect())
330            .collect();
331        Self { max_seq_len, dim, n_buckets, bucket_embeddings }
332    }
333
334    /// Map relative distance to a bucket index (logarithmic bucketing for large gaps).
335    pub fn relative_position_bucket(&self, rel_pos: i32) -> usize {
336        let abs_pos = rel_pos.unsigned_abs() as usize;
337        let half = self.n_buckets / 2;
338        if abs_pos < half {
339            // Small distances: exact bucket
340            abs_pos.min(half.saturating_sub(1))
341        } else {
342            // Large distances: logarithmic bucket
343            let log_bucket = (abs_pos as f32 / half as f32).ln() / (self.max_seq_len as f32 / half as f32).ln().max(1e-8);
344            let bucket = half + (log_bucket * half as f32) as usize;
345            bucket.min(self.n_buckets - 1)
346        }
347    }
348
349    /// Compute position bias matrix: shape (seq_len × seq_len × dim).
350    /// Returns a flat vector of (seq_len × seq_len) bucket indices.
351    pub fn position_bias_indices(&self, seq_len: usize) -> Vec<usize> {
352        let n = seq_len.min(self.max_seq_len);
353        let mut indices = vec![0usize; n * n];
354        for i in 0..n {
355            for j in 0..n {
356                let rel = j as i32 - i as i32;
357                indices[i * n + j] = self.relative_position_bucket(rel);
358            }
359        }
360        indices
361    }
362
363    /// Look up position embeddings for a (seq_len × seq_len) attention matrix.
364    pub fn get_position_biases(&self, seq_len: usize) -> Vec<Vec<f32>> {
365        let indices = self.position_bias_indices(seq_len);
366        indices.iter().map(|&b| self.bucket_embeddings[b].clone()).collect()
367    }
368}
369
370// ── GroupLasso ────────────────────────────────────────────────────────────────
371
372/// Group LASSO: apply L1 penalty over group norms for structured sparsity.
373///
374/// The penalty is λ * Σ_g ‖w_g‖₂, promoting whole groups to be zero.
375/// Uses proximal gradient descent with exact group-prox operator.
pub struct GroupLasso {
    /// Feature groups: list of groups, each containing feature indices.
    /// Indices outside the feature range are ignored during fitting.
    pub groups: Vec<Vec<usize>>,
    /// Group L2 regularization weight (λ in λ·Σ_g‖w_g‖₂).
    pub lambda: f32,
    /// Learning rate (proximal gradient step size).
    pub learning_rate: f32,
    /// Maximum optimization iterations.
    pub max_iter: usize,
}
386
387impl GroupLasso {
388    /// Fit Group LASSO via proximal gradient descent on squared loss.
389    pub fn fit(&self, x_data: &[Vec<f32>], y: &[f32]) -> Vec<f32> {
390        if x_data.is_empty() || y.is_empty() {
391            return Vec::new();
392        }
393        let n_features = x_data[0].len();
394        let n_samples = x_data.len();
395        let mut w = vec![0.0_f32; n_features];
396
397        for _ in 0..self.max_iter {
398            // Compute gradient of squared loss: ∇ = (1/n) X^T (Xw - y)
399            let preds: Vec<f32> = x_data.iter().map(|row| dot(row, &w)).collect();
400            let residuals: Vec<f32> = preds.iter().zip(y.iter()).map(|(&p, &yi)| p - yi).collect();
401            let mut grad = vec![0.0_f32; n_features];
402            for (i, row) in x_data.iter().enumerate() {
403                for (j, &xij) in row.iter().enumerate() {
404                    grad[j] += xij * residuals[i] / n_samples as f32;
405                }
406            }
407
408            // Gradient step
409            let mut w_new = w.clone();
410            for j in 0..n_features {
411                w_new[j] -= self.learning_rate * grad[j];
412            }
413
414            // Group proximal operator: prox_{λ‖·‖₂}(w_g) = max(0, 1 - λ/‖w_g‖) w_g
415            for group in &self.groups {
416                let group_norm: f32 = group.iter()
417                    .filter(|&&j| j < n_features)
418                    .map(|&j| w_new[j] * w_new[j])
419                    .sum::<f32>()
420                    .sqrt();
421                let threshold = self.lambda * self.learning_rate;
422                if group_norm < threshold {
423                    for &j in group.iter().filter(|&&j| j < n_features) {
424                        w_new[j] = 0.0;
425                    }
426                } else {
427                    let scale = 1.0 - threshold / group_norm;
428                    for &j in group.iter().filter(|&&j| j < n_features) {
429                        w_new[j] *= scale;
430                    }
431                }
432            }
433            w = w_new;
434        }
435        w
436    }
437
438    /// Predict target values using fitted weights.
439    pub fn predict(&self, x_data: &[Vec<f32>], w: &[f32]) -> Vec<f32> {
440        x_data.iter().map(|row| dot(row, w)).collect()
441    }
442
443    /// Compute the fraction of groups that are entirely zero.
444    pub fn group_sparsity(&self, w: &[f32]) -> f32 {
445        if self.groups.is_empty() {
446            return 0.0;
447        }
448        let zero_groups = self.groups.iter().filter(|group| {
449            group.iter().all(|&j| j >= w.len() || w[j].abs() < 1e-8)
450        }).count();
451        zero_groups as f32 / self.groups.len() as f32
452    }
453}
454
455// ── StructuredPruningMask ─────────────────────────────────────────────────────
456
457/// N:M sparsity mask for structured sparsity (e.g., 2:4 for Ampere GPUs).
458///
459/// In N:M sparsity, exactly N out of every M consecutive elements are non-zero.
460/// This achieves exactly (1 - N/M) structured sparsity that maps efficiently to
461/// hardware accelerators.
pub struct StructuredPruningMask {
    /// Number of non-zero elements to retain per group (N).
    pub n_keep: usize,
    /// Group size M. A value of 0 is treated as 1 (no grouping).
    pub group_size: usize,
}

impl StructuredPruningMask {
    /// Create a 2:4 sparsity mask (default for Ampere/Hopper GPUs).
    pub fn nm_24() -> Self {
        Self { n_keep: 2, group_size: 4 }
    }

    /// Apply N:M sparsity to a weight tensor.
    ///
    /// Within every M consecutive elements, only the N largest by magnitude
    /// are kept; the rest are zeroed. Returns the pruned weight vector and
    /// the binary mask (`true` = kept). A trailing partial group keeps at
    /// most N of its elements.
    pub fn apply(&self, weights: &[f32]) -> (Vec<f32>, Vec<bool>) {
        let n = weights.len();
        let mut pruned = weights.to_vec();
        let mut mask = vec![false; n];
        // Guard group_size == 0, which previously caused a divide-by-zero panic.
        let m = self.group_size.max(1);
        let k = self.n_keep.min(m);

        let n_groups = (n + m - 1) / m; // ceil(n / m)
        for g in 0..n_groups {
            let start = g * m;
            let end = (start + m).min(n);
            let group_len = end - start;

            // Rank the group's indices by descending |weight|
            let mut indices: Vec<usize> = (start..end).collect();
            indices.sort_by(|&a, &b| {
                weights[b]
                    .abs()
                    .partial_cmp(&weights[a].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // Keep top-k by magnitude, zero the rest
            let keep_count = k.min(group_len);
            for (rank, &idx) in indices.iter().enumerate() {
                if rank < keep_count {
                    mask[idx] = true;
                } else {
                    pruned[idx] = 0.0;
                }
            }
        }
        (pruned, mask)
    }

    /// Fraction of masked-out (zeroed) positions; 0.0 for an empty mask.
    pub fn sparsity(mask: &[bool]) -> f32 {
        if mask.is_empty() {
            return 0.0;
        }
        let zeros = mask.iter().filter(|&&b| !b).count();
        zeros as f32 / mask.len() as f32
    }
}
520
521// ── ChannelPruner ─────────────────────────────────────────────────────────────
522
523/// Iterative channel importance scoring and pruning for CNNs.
524///
525/// Supports three importance criteria:
526/// - L1 norm of filter weights
527/// - Taylor first-order approximation (weight × gradient)
528/// - FPGM (Filter Pruning via Geometric Median)
#[derive(Debug, Clone)]
pub enum ChannelImportanceCriterion {
    /// Sum of absolute weight values per filter.
    L1Norm,
    /// First-order Taylor: |w| × |∂L/∂w| approximation.
    /// Gradients should be passed to `compute_importance`; when absent, the
    /// filter weights themselves are used, and missing entries fall back to
    /// the first filter's weights.
    TaylorExpansion,
    /// Filter Pruning via Geometric Median (He et al. 2019).
    /// Scored here as mean pairwise L2 distance to all other filters.
    Fpgm,
}
538
539/// Channel pruner that scores and removes channels from a layer.
pub struct ChannelPruner {
    /// Fraction of channels to prune (rounded to the nearest channel count
    /// and clamped to the total number of channels in `prune_mask`).
    pub prune_ratio: f32,
    /// Importance scoring criterion.
    pub criterion: ChannelImportanceCriterion,
}
546
547impl ChannelPruner {
548    /// Compute importance scores for each filter (channel).
549    ///
550    /// `filters` is a list of filters, each represented as a flat weight vector.
551    /// `gradients` is optional; required for Taylor criterion.
552    pub fn compute_importance(
553        &self,
554        filters: &[Vec<f32>],
555        gradients: Option<&[Vec<f32>]>,
556    ) -> Vec<f32> {
557        match &self.criterion {
558            ChannelImportanceCriterion::L1Norm => {
559                filters.iter().map(|f| f.iter().map(|w| w.abs()).sum::<f32>()).collect()
560            }
561            ChannelImportanceCriterion::TaylorExpansion => {
562                let grads = gradients.unwrap_or(filters);
563                filters.iter().zip(grads.iter().chain(std::iter::repeat(&filters[0]))).map(|(f, g)| {
564                    f.iter().zip(g.iter()).map(|(w, dw)| (w * dw).abs()).sum::<f32>()
565                }).collect()
566            }
567            ChannelImportanceCriterion::Fpgm => {
568                // FPGM: importance = negative of minimum distance to geometric median
569                // Approximate via pairwise distances: less important = closer to median cluster
570                let n = filters.len();
571                (0..n).map(|i| {
572                    // Score = inverse of average pairwise L2 distance to other filters
573                    // (higher distance from the median cluster → keep)
574                    let sum_dist: f32 = (0..n).filter(|&j| j != i).map(|j| {
575                        filters[i].iter().zip(filters[j].iter()).map(|(a, b)| (a - b).powi(2)).sum::<f32>().sqrt()
576                    }).sum();
577                    if n <= 1 { norm(&filters[i]) } else { sum_dist / (n - 1) as f32 }
578                }).collect()
579            }
580        }
581    }
582
583    /// Select which channels to prune given importance scores.
584    ///
585    /// Returns a boolean mask (true = keep, false = prune).
586    pub fn prune_mask(&self, importance: &[f32]) -> Vec<bool> {
587        let n = importance.len();
588        let n_prune = ((n as f32 * self.prune_ratio).round() as usize).min(n);
589        let mut indexed: Vec<(usize, f32)> = importance.iter().cloned().enumerate().collect();
590        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
591        let mut mask = vec![true; n];
592        for (idx, _) in indexed.iter().take(n_prune) {
593            mask[*idx] = false;
594        }
595        mask
596    }
597
598    /// Apply pruning: zero out pruned channel weights.
599    pub fn apply_pruning(&self, filters: &[Vec<f32>], mask: &[bool]) -> Vec<Vec<f32>> {
600        filters.iter().zip(mask.iter()).map(|(f, &keep)| {
601            if keep { f.clone() } else { vec![0.0_f32; f.len()] }
602        }).collect()
603    }
604}
605
606// ── LayerPruner ───────────────────────────────────────────────────────────────
607
608/// Whole-layer importance estimation via Fisher information approximation.
609///
610/// Uses the empirical Fisher (squared gradients) to estimate how much each
611/// layer contributes to the loss. Layers with near-zero Fisher information
612/// can be removed with minimal accuracy impact.
pub struct LayerPruner {
    /// Threshold below which a layer is considered unimportant.
    pub importance_threshold: f32,
}

impl LayerPruner {
    /// Estimate layer importance from weight and gradient vectors.
    ///
    /// Empirical Fisher approximation: I ≈ (1/n) Σ (gᵢ·wᵢ)², computed over
    /// the overlapping prefix of the two slices. Returns 0.0 when either
    /// slice is empty (the original divided by zero and returned NaN when
    /// `gradients` was empty but `weights` was not).
    pub fn layer_importance(&self, weights: &[f32], gradients: &[f32]) -> f32 {
        let n = weights.len().min(gradients.len());
        if n == 0 {
            return 0.0;
        }
        // Fisher trace: squared gradient, weighted by weight magnitude
        weights
            .iter()
            .zip(gradients.iter())
            .map(|(&w, &g)| (g * w).powi(2))
            .sum::<f32>()
            / n as f32
    }

    /// Determine which layers should be pruned given their importances.
    ///
    /// Returns a boolean vector (true = keep, false = prune).
    pub fn select_layers_to_prune(&self, importances: &[f32]) -> Vec<bool> {
        importances.iter().map(|&imp| imp >= self.importance_threshold).collect()
    }

    /// Sensitivity ranking: layer indices sorted from least to most important.
    pub fn rank_layers(&self, importances: &[f32]) -> Vec<usize> {
        let mut indexed: Vec<(usize, f32)> = importances.iter().cloned().enumerate().collect();
        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        indexed.iter().map(|(i, _)| *i).collect()
    }
}
649
650// ── SparseCodingLayer ─────────────────────────────────────────────────────────
651
652/// Sparse coding layer replacing a dense linear layer.
653///
654/// Learns a dictionary D and encodes inputs via LISTA (Learned ISTA, Gregor 2010).
655/// The forward pass runs a fixed number of ISTA steps with learned step sizes,
656/// producing sparse activations.
pub struct SparseCodingLayer {
    /// Input dimension.
    pub input_dim: usize,
    /// Dictionary size (number of atoms / output sparse codes).
    pub dict_size: usize,
    /// Number of unrolled ISTA steps.
    pub n_steps: usize,
    /// L1 regularization weight.
    pub lambda: f32,
    /// Dictionary matrix D: (dict_size × input_dim); atoms are L2-normalized
    /// by `new`.
    pub dictionary: Vec<Vec<f32>>,
    /// Learned step size per ISTA iteration (one per step; missing entries
    /// fall back to 0.1 in `forward`).
    pub step_sizes: Vec<f32>,
    /// Recurrent weight matrix We: (dict_size × dict_size) = I - step * D^T D.
    // NOTE(review): `new` initializes this as a scaled identity rather than
    // the full I - step·DᵀD product — presumably refined during training;
    // confirm.
    pub we_matrix: Vec<Vec<f32>>,
}
673
674impl SparseCodingLayer {
675    /// Initialize with random dictionary (Xavier) and default step sizes.
676    pub fn new(input_dim: usize, dict_size: usize, n_steps: usize, lambda: f32, rng: &mut impl Rng) -> Self {
677        let scale = (2.0 / (input_dim + dict_size) as f32).sqrt();
678        let dictionary: Vec<Vec<f32>> = (0..dict_size)
679            .map(|_| (0..input_dim).map(|_| box_muller(rng) * scale).collect())
680            .collect();
681
682        // Normalize dictionary atoms
683        let dictionary: Vec<Vec<f32>> = dictionary.into_iter().map(|mut atom| {
684            let n = norm(&atom);
685            if n > 1e-8 { for v in atom.iter_mut() { *v /= n; } }
686            atom
687        }).collect();
688
689        // Default step size heuristic: 1 / (largest singular value estimate)
690        let default_step = 0.1_f32;
691        let step_sizes = vec![default_step; n_steps];
692
693        // We = I - step * D^T D  (approximate; simplified as identity for init)
694        let we_matrix = (0..dict_size)
695            .map(|i| (0..dict_size).map(|j| if i == j { 1.0 - default_step } else { 0.0 }).collect())
696            .collect();
697
698        Self { input_dim, dict_size, n_steps, lambda, dictionary, step_sizes, we_matrix }
699    }
700
701    /// Forward pass: encode input via LISTA (unrolled ISTA).
702    ///
703    /// Returns sparse code z of dimension dict_size.
704    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
705        // Initial response: D^T x
706        let mut z: Vec<f32> = self.dictionary.iter()
707            .map(|atom| dot(atom, x))
708            .collect();
709
710        // Threshold initial response
711        let init_thresh = self.lambda * self.step_sizes.first().copied().unwrap_or(0.1);
712        for v in z.iter_mut() {
713            *v = LassoEncoder::soft_threshold(*v, init_thresh);
714        }
715
716        // Unrolled ISTA steps
717        for step in 0..self.n_steps {
718            let step_size = self.step_sizes.get(step).copied().unwrap_or(0.1);
719            let threshold = self.lambda * step_size;
720
721            // Compute We * z + step * D^T * x
722            let we_z: Vec<f32> = self.we_matrix.iter().map(|row| dot(row, &z)).collect();
723            let dtx: Vec<f32> = self.dictionary.iter().map(|atom| dot(atom, x) * step_size).collect();
724
725            let mut z_new: Vec<f32> = (0..self.dict_size).map(|i| we_z[i] + dtx[i]).collect();
726            for v in z_new.iter_mut() {
727                *v = LassoEncoder::soft_threshold(*v, threshold);
728            }
729            z = z_new;
730        }
731        z
732    }
733
734    /// Reconstruct the input from its sparse code.
735    pub fn reconstruct(&self, z: &[f32]) -> Vec<f32> {
736        let mut x_hat = vec![0.0_f32; self.input_dim];
737        for (atom, &zi) in self.dictionary.iter().zip(z.iter()) {
738            for (xh, &a) in x_hat.iter_mut().zip(atom.iter()) {
739                *xh += zi * a;
740            }
741        }
742        x_hat
743    }
744
745    /// Reconstruction loss: ‖x - D z‖₂² / input_dim.
746    pub fn reconstruction_loss(&self, x: &[f32]) -> f32 {
747        let z = self.forward(x);
748        let x_hat = self.reconstruct(&z);
749        x.iter().zip(x_hat.iter()).map(|(a, b)| (a - b).powi(2)).sum::<f32>()
750            / self.input_dim as f32
751    }
752}
753
754// ── ListaNetwork ──────────────────────────────────────────────────────────────
755
756/// Learned ISTA (LISTA) network: unrolled gradient descent for sparse coding.
757///
758/// Implements the full LISTA architecture (Gregor & LeCun 2010) with
759/// learnable forward and recurrent weight matrices Wd and We.
pub struct ListaNetwork {
    /// Input dimension.
    pub input_dim: usize,
    /// Sparse code dimension (number of dictionary atoms).
    pub code_dim: usize,
    /// Number of unrolled iterations (network depth).
    pub n_layers: usize,
    /// L1 threshold per layer (missing entries fall back to 0.1 in `forward`).
    pub thresholds: Vec<f32>,
    /// Forward weights Wd: (code_dim × input_dim) — maps input to initial response.
    pub wd: Vec<Vec<f32>>,
    /// Recurrent weights We: (code_dim × code_dim) — maps previous code to next.
    pub we: Vec<Vec<f32>>,
}
774
775impl ListaNetwork {
776    /// Initialize LISTA network with Xavier weights.
777    pub fn new(input_dim: usize, code_dim: usize, n_layers: usize, lambda: f32, rng: &mut impl Rng) -> Self {
778        let scale_wd = (2.0 / (input_dim + code_dim) as f32).sqrt();
779        let scale_we = (2.0 / (code_dim + code_dim) as f32).sqrt();
780        let wd = (0..code_dim).map(|_| (0..input_dim).map(|_| box_muller(rng) * scale_wd).collect()).collect();
781        let we = (0..code_dim).map(|_| (0..code_dim).map(|_| box_muller(rng) * scale_we).collect()).collect();
782        let thresholds = vec![lambda; n_layers];
783        Self { input_dim, code_dim, n_layers, thresholds, wd, we }
784    }
785
786    /// Forward pass through the unrolled ISTA network.
787    ///
788    /// z_{t+1} = soft_threshold(We z_t + Wd x, θ_t)
789    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
790        let mut z = vec![0.0_f32; self.code_dim];
791        let init: Vec<f32> = self.wd.iter().map(|row| dot(row, x)).collect();
792        let thresh0 = self.thresholds.first().copied().unwrap_or(0.1);
793        for (zi, &init_i) in z.iter_mut().zip(init.iter()) {
794            *zi = LassoEncoder::soft_threshold(init_i, thresh0);
795        }
796
797        for layer in 1..self.n_layers {
798            let thresh = self.thresholds.get(layer).copied().unwrap_or(0.1);
799            let we_z: Vec<f32> = self.we.iter().map(|row| dot(row, &z)).collect();
800            let wd_x: Vec<f32> = self.wd.iter().map(|row| dot(row, x)).collect();
801            let mut z_new: Vec<f32> = (0..self.code_dim).map(|i| we_z[i] + wd_x[i]).collect();
802            for v in z_new.iter_mut() {
803                *v = LassoEncoder::soft_threshold(*v, thresh);
804            }
805            z = z_new;
806        }
807        z
808    }
809
810    /// Sparsity of the output code (fraction of zero entries).
811    pub fn output_sparsity(&self, x: &[f32]) -> f32 {
812        let z = self.forward(x);
813        let zeros = z.iter().filter(|&&v| v == 0.0).count();
814        zeros as f32 / self.code_dim as f32
815    }
816}
817
818// ── PredictiveCodingLayer ─────────────────────────────────────────────────────
819
/// Predictive coding layer (Rao & Ballard 1999).
///
/// Maintains two populations of units:
/// - Representation units r: the current estimate of the hidden state
/// - Error units e: the prediction error (actual - predicted)
///
/// Learning minimizes the sum of squared prediction errors across layers.
#[derive(Debug, Clone)]
pub struct PredictiveCodingLayer {
    /// Input dimension.
    pub input_dim: usize,
    /// Hidden representation dimension.
    pub hidden_dim: usize,
    /// Top-down prediction weight matrix (input_dim × hidden_dim).
    pub prediction_w: Vec<Vec<f32>>,
    /// Representation update rate.
    pub r_lr: f32,
    /// Number of inference iterations to converge r.
    pub n_inference_steps: usize,
}
839
840impl PredictiveCodingLayer {
841    /// Initialize with random prediction weights.
842    pub fn new(input_dim: usize, hidden_dim: usize, rng: &mut impl Rng) -> Self {
843        let scale = (1.0 / hidden_dim as f32).sqrt();
844        let prediction_w = (0..input_dim)
845            .map(|_| (0..hidden_dim).map(|_| box_muller(rng) * scale).collect())
846            .collect();
847        Self { input_dim, hidden_dim, prediction_w, r_lr: 0.1, n_inference_steps: 20 }
848    }
849
850    /// Compute top-down prediction from hidden representation.
851    pub fn predict(&self, r: &[f32]) -> Vec<f32> {
852        self.prediction_w.iter().map(|row| dot(row, r)).collect()
853    }
854
855    /// Compute prediction error: actual - predicted.
856    pub fn prediction_error(&self, actual: &[f32], r: &[f32]) -> Vec<f32> {
857        let pred = self.predict(r);
858        actual.iter().zip(pred.iter()).map(|(a, p)| a - p).collect()
859    }
860
861    /// Run inference to find the optimal hidden representation r* for input x.
862    ///
863    /// Gradient descent on the free energy:
864    /// ∂F/∂r = W^T e - r  (with unit precision assumption)
865    pub fn infer(&self, x: &[f32]) -> (Vec<f32>, Vec<f32>) {
866        let mut r = vec![0.0_f32; self.hidden_dim];
867
868        for _ in 0..self.n_inference_steps {
869            let e = self.prediction_error(x, &r);
870            // Gradient: W^T e
871            let grad_r: Vec<f32> = (0..self.hidden_dim).map(|j| {
872                self.prediction_w.iter().zip(e.iter()).map(|(row, &ei)| row[j] * ei).sum::<f32>()
873            }).collect();
874            for (ri, &g) in r.iter_mut().zip(grad_r.iter()) {
875                *ri += self.r_lr * g;
876            }
877        }
878        let e_final = self.prediction_error(x, &r);
879        (r, e_final)
880    }
881
882    /// Compute the free energy (total squared prediction error).
883    pub fn free_energy(&self, x: &[f32]) -> f32 {
884        let (_, e) = self.infer(x);
885        e.iter().map(|v| v * v).sum::<f32>() / self.input_dim as f32
886    }
887}
888
889// ── CompressedSensingMatrix ───────────────────────────────────────────────────
890
/// RIP-satisfying random measurement matrices for compressed sensing.
///
/// Provides Gaussian, Bernoulli, and structured (subsampled DFT / circulant)
/// matrices with theoretical RIP guarantees for appropriate (m, n, s) settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CsMatrixType {
    /// i.i.d. Gaussian N(0, 1/m) entries — RIP holds with m = O(s log(n/s)).
    Gaussian,
    /// i.i.d. Bernoulli ±1/sqrt(m) entries — same RIP bound as Gaussian.
    Bernoulli,
    /// Subsampled Hadamard (Walsh-Hadamard) matrix for structured sparsity.
    /// Only valid when n is a power of 2.
    SubsampledHadamard,
}

/// Compressed sensing measurement matrix with RIP analysis.
#[derive(Debug, Clone)]
pub struct CompressedSensingMatrix {
    /// Number of measurements.
    pub m_rows: usize,
    /// Signal length.
    pub n_cols: usize,
    /// Matrix construction type.
    pub matrix_type: CsMatrixType,
    /// The actual measurement matrix stored row-major.
    pub matrix: Vec<Vec<f32>>,
}
917
918impl CompressedSensingMatrix {
919    /// Construct a new compressed sensing matrix.
920    pub fn new(m_rows: usize, n_cols: usize, matrix_type: CsMatrixType, rng: &mut impl Rng) -> Self {
921        let matrix = match &matrix_type {
922            CsMatrixType::Gaussian => {
923                let scale = 1.0 / (m_rows as f32).sqrt();
924                (0..m_rows).map(|_| {
925                    (0..n_cols).map(|_| box_muller(rng) * scale).collect()
926                }).collect()
927            }
928            CsMatrixType::Bernoulli => {
929                let scale = 1.0 / (m_rows as f32).sqrt();
930                (0..m_rows).map(|_| {
931                    (0..n_cols).map(|_| if rng.random::<f32>() > 0.5 { scale } else { -scale }).collect()
932                }).collect()
933            }
934            CsMatrixType::SubsampledHadamard => {
935                // Build a Hadamard matrix of size n_cols (next power of 2)
936                let nh = n_cols.next_power_of_two();
937                let mut h = vec![vec![1.0_f32; nh]; nh];
938                let mut step = 1usize;
939                while step < nh {
940                    for i in (0..nh).step_by(2 * step) {
941                        for j in i..(i + step).min(nh) {
942                            let a = h[j][0..nh].to_vec();
943                            let b = h[j + step][0..nh].to_vec();
944                            for k in 0..nh {
945                                h[j][k] = a[k] + b[k];
946                                h[j + step][k] = a[k] - b[k];
947                            }
948                        }
949                    }
950                    step *= 2;
951                }
952                let scale = 1.0 / (m_rows as f32 * nh as f32).sqrt();
953                // Subsample m_rows rows at random
954                let mut row_indices: Vec<usize> = (0..nh).collect();
955                for i in 0..m_rows.min(nh) {
956                    let j = i + rng.random_range(0..(nh - i));
957                    row_indices.swap(i, j);
958                }
959                row_indices.truncate(m_rows.min(nh));
960                row_indices.iter().map(|&r| {
961                    h[r][..n_cols].iter().map(|&v| v * scale).collect()
962                }).collect()
963            }
964        };
965        Self { m_rows, n_cols, matrix_type, matrix }
966    }
967
968    /// Apply the measurement matrix: y = Φ x.
969    pub fn measure(&self, x: &[f32]) -> Vec<f32> {
970        self.matrix.iter().map(|row| dot(row, &x[..x.len().min(self.n_cols)])).collect()
971    }
972
973    /// Compute Φ^T v.
974    pub fn transpose_apply(&self, v: &[f32]) -> Vec<f32> {
975        let mut result = vec![0.0_f32; self.n_cols];
976        for (row, &vi) in self.matrix.iter().zip(v.iter()) {
977            for (r, &a) in result.iter_mut().zip(row.iter()) {
978                *r += a * vi;
979            }
980        }
981        result
982    }
983
984    /// Theoretical sufficient number of measurements for s-sparse recovery.
985    ///
986    /// Based on m ≥ C · s · log(n / s) with constant C ≈ 4.
987    pub fn sufficient_measurements(n: usize, s: usize) -> usize {
988        if s == 0 || n == 0 {
989            return 1;
990        }
991        let c = 4.0_f32;
992        (c * s as f32 * (n as f32 / s as f32).ln()).ceil() as usize
993    }
994}
995
996// ── BasisPursuitDenoise ───────────────────────────────────────────────────────
997
/// BPDN solver: min ‖x‖₁ s.t. ‖Ax - b‖₂ ≤ σ.
///
/// Implemented via ADMM with an augmented Lagrangian formulation equivalent
/// to LASSO with parameter λ = 1/σ (approximately).
#[derive(Debug, Clone, Copy)]
pub struct BasisPursuitDenoise {
    /// Noise tolerance σ.
    pub sigma: f32,
    /// ADMM penalty parameter.
    pub rho: f32,
    /// Maximum ADMM iterations.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: f32,
}
1012
impl BasisPursuitDenoise {
    /// Solve BPDN: minimize ‖x‖₁ subject to ‖Ax - b‖₂ ≤ σ.
    ///
    /// Uses ADMM with the Lagrangian: ½‖Ax - b‖₂² + λ‖z‖₁ + ρ/2‖x - z + u‖₂²
    /// where λ = 1 / (sigma * m).
    ///
    /// `a` is the m×n measurement matrix (row-major); all rows are assumed to
    /// have the same length as `a[0]`. Returns the sparse iterate `z`
    /// (empty when `a` or `b` is empty).
    pub fn solve(&self, a: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
        let m = a.len();
        let n = if m > 0 { a[0].len() } else { 0 };
        if n == 0 || m == 0 {
            return Vec::new();
        }
        // σ is clamped away from zero so λ stays finite.
        let lambda = 1.0 / (self.sigma.max(1e-6) * m as f32);
        let rho = self.rho;

        // ADMM state: x = least-squares iterate, z = sparse iterate,
        // u = scaled dual variable.
        let mut x = vec![0.0_f32; n];
        let mut z = vec![0.0_f32; n];
        let mut u = vec![0.0_f32; n];

        // A^T b
        let atb: Vec<f32> = (0..n).map(|j| {
            a.iter().zip(b.iter()).map(|(row, &bi)| row[j] * bi).sum::<f32>()
        }).collect();

        for _ in 0..self.max_iter {
            // x-update: (A^T A + rho I) x = A^T b + rho (z - u)
            let rhs: Vec<f32> = (0..n).map(|i| atb[i] + rho * (z[i] - u[i])).collect();
            // CG solve for (A^T A + rho I) x = rhs
            // (warm-started from the previous x; capped at 50 inner iterations)
            x = bpdn_cg(a, rho, &rhs, &x, 50);

            let z_old = z.clone();
            // z-update: soft threshold with lambda / rho
            for i in 0..n {
                z[i] = LassoEncoder::soft_threshold(x[i] + u[i], lambda / rho);
            }
            // u-update
            for i in 0..n {
                u[i] += x[i] - z[i];
            }

            // Stop once both the primal residual ‖x − z‖₂ and the dual
            // residual ρ‖z − z_old‖₂ fall below the absolute tolerance.
            let primal_res: f32 = (0..n).map(|i| (x[i] - z[i]).powi(2)).sum::<f32>().sqrt();
            let dual_res: f32 = (0..n).map(|i| (rho * (z[i] - z_old[i])).powi(2)).sum::<f32>().sqrt();
            if primal_res < self.tol && dual_res < self.tol {
                break;
            }
        }
        z
    }

    /// Residual norm ‖Ax - b‖₂.
    pub fn residual_norm(a: &[Vec<f32>], b: &[f32], x: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(row, &bi)| {
            let ax_i: f32 = dot(row, x);
            (ax_i - bi).powi(2)
        }).sum::<f32>().sqrt()
    }
}
1069
1070/// Conjugate gradient for (A^T A + rho I) x = b used in BPDN ADMM.
1071fn bpdn_cg(a: &[Vec<f32>], rho: f32, b: &[f32], x0: &[f32], max_iter: usize) -> Vec<f32> {
1072    let n = b.len();
1073    let mut x = x0.to_vec();
1074    // Compute r = b - (A^T A + rho I) x0
1075    let ax: Vec<f32> = a.iter().map(|row| dot(row, &x)).collect();
1076    let atax: Vec<f32> = (0..n).map(|j| a.iter().zip(ax.iter()).map(|(row, &ai)| row[j] * ai).sum::<f32>()).collect();
1077    let mut r: Vec<f32> = (0..n).map(|i| b[i] - atax[i] - rho * x[i]).collect();
1078    let mut p = r.clone();
1079    let mut rsold: f32 = r.iter().map(|&v| v * v).sum();
1080
1081    for _ in 0..max_iter {
1082        if rsold < 1e-12 { break; }
1083        let ap: Vec<f32> = a.iter().map(|row| dot(row, &p)).collect();
1084        let atap: Vec<f32> = (0..n).map(|j| a.iter().zip(ap.iter()).map(|(row, &ai)| row[j] * ai).sum::<f32>()).collect();
1085        let ap_full: Vec<f32> = (0..n).map(|i| atap[i] + rho * p[i]).collect();
1086        let denom: f32 = p.iter().zip(ap_full.iter()).map(|(&pi, &api)| pi * api).sum();
1087        if denom.abs() < 1e-14 { break; }
1088        let alpha = rsold / denom;
1089        for i in 0..n { x[i] += alpha * p[i]; r[i] -= alpha * ap_full[i]; }
1090        let rsnew: f32 = r.iter().map(|&v| v * v).sum();
1091        let beta = rsnew / rsold.max(1e-14);
1092        for i in 0..n { p[i] = r[i] + beta * p[i]; }
1093        rsold = rsnew;
1094    }
1095    x
1096}
1097
1098// ── RecoveryGuarantees ────────────────────────────────────────────────────────
1099
/// Theoretical compressed sensing recovery guarantees and RIP analysis.
///
/// Stateless namespace type: all functionality lives in associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct RecoveryGuarantees;
1102
1103impl RecoveryGuarantees {
1104    /// Estimate the RIP-s constant δ_s via random s-sparse test vectors.
1105    ///
1106    /// δ_s = max over s-sparse unit vectors x: |‖Ax‖₂² - 1|
1107    pub fn rip_constant(
1108        a: &[Vec<f32>],
1109        s: usize,
1110        n_trials: usize,
1111        rng: &mut impl Rng,
1112    ) -> f32 {
1113        let n = if a.is_empty() { 0 } else { a[0].len() };
1114        if n == 0 { return 0.0; }
1115        let mut max_delta = 0.0_f32;
1116
1117        for _ in 0..n_trials {
1118            // Random s-sparse unit vector
1119            let mut support: Vec<usize> = (0..n).collect();
1120            for i in 0..s.min(n) {
1121                let j = i + rng.random_range(0..(n - i));
1122                support.swap(i, j);
1123            }
1124            let support = &support[..s.min(n)];
1125
1126            let mut x = vec![0.0_f32; n];
1127            let mut sq_sum = 0.0_f32;
1128            for &i in support {
1129                let v = box_muller(rng);
1130                x[i] = v;
1131                sq_sum += v * v;
1132            }
1133            if sq_sum < 1e-10 { continue; }
1134            let x_norm = sq_sum.sqrt();
1135            for v in x.iter_mut() { *v /= x_norm; }
1136
1137            let ax: Vec<f32> = a.iter().map(|row| dot(row, &x)).collect();
1138            let ax_norm_sq: f32 = ax.iter().map(|&v| v * v).sum();
1139            let delta = (ax_norm_sq - 1.0).abs();
1140            if delta > max_delta { max_delta = delta; }
1141        }
1142        max_delta
1143    }
1144
1145    /// Phase transition boundary: minimum measurement ratio m/n for s/n sparsity ratio.
1146    ///
1147    /// Based on Donoho-Tanner phase transition (approximation).
1148    /// Returns the minimum undersampling ratio m/n needed to recover s-sparse signals.
1149    pub fn phase_transition(sparsity_ratio: f32) -> f32 {
1150        // Empirical approximation to the Donoho-Tanner curve
1151        // δ* ≈ f(ρ) where ρ = s/n, δ = m/n
1152        let rho = sparsity_ratio.clamp(0.0, 1.0);
1153        // Gaussian approximation: δ ≈ 2ρ log(1/ρ) + rho*(1 + log(2π)) for small ρ
1154        if rho < 1e-6 {
1155            return 0.0;
1156        }
1157        let delta = 2.0 * rho * (1.0 / rho).ln() + rho * (1.0 + (2.0 * std::f32::consts::PI).ln());
1158        delta.min(1.0)
1159    }
1160
1161    /// Check if the RIP condition is sufficient for exact recovery.
1162    ///
1163    /// Basis Pursuit guarantees exact recovery if δ_{2s} < √2 - 1 ≈ 0.4142.
1164    pub fn is_rip_sufficient(delta_2s: f32) -> bool {
1165        delta_2s < (2.0_f32.sqrt() - 1.0)
1166    }
1167}
1168
1169// ── CsMetrics ─────────────────────────────────────────────────────────────────
1170
/// Metrics for evaluating compressed sensing recovery quality.
#[derive(Debug, Clone, Copy, Default)]
pub struct CsMetrics;

impl CsMetrics {
    /// Recovery signal-to-noise ratio in dB: 10 log10(‖x‖² / ‖x - x̂‖²).
    ///
    /// Clamped to +100 dB for (near-)exact recovery and −100 dB for a
    /// (near-)zero original, avoiding ±inf/NaN from the logarithm.
    pub fn recovery_snr(original: &[f32], recovered: &[f32]) -> f32 {
        let signal_power: f32 = original.iter().map(|&v| v * v).sum();
        let error_power: f32 = original.iter().zip(recovered.iter()).map(|(a, b)| (a - b).powi(2)).sum();
        if error_power < 1e-14 { return 100.0; }
        if signal_power < 1e-14 { return -100.0; }
        10.0 * (signal_power / error_power).log10()
    }

    /// Support recovery rate: fraction of true support indices that are recovered.
    ///
    /// `true_support`: indices of non-zero entries in the original signal.
    /// `recovered`: the recovered signal vector.
    /// `threshold`: values with |v| above this count as non-zero in the recovery.
    ///
    /// Indices outside `recovered` count as misses (as before).
    pub fn support_recovery_rate(true_support: &[usize], recovered: &[f32], threshold: f32) -> f32 {
        if true_support.is_empty() { return 1.0; }
        // Check each true index directly — O(s) instead of materializing the
        // recovered support and scanning it once per index (O(s·n)).
        let hits = true_support
            .iter()
            .filter(|&&i| recovered.get(i).is_some_and(|v| v.abs() > threshold))
            .count();
        hits as f32 / true_support.len() as f32
    }

    /// Exact support recovery: true iff the set of indices with |v| > threshold
    /// in `recovered` equals the set of true support indices.
    ///
    /// Duplicate indices in `true_support` are tolerated (deduplicated).
    pub fn exact_support_recovery(true_support: &[usize], recovered: &[f32], threshold: f32) -> bool {
        // `enumerate` yields ascending indices, so this is already sorted.
        let recovered_support: Vec<usize> = recovered.iter().enumerate()
            .filter(|(_, &v)| v.abs() > threshold)
            .map(|(i, _)| i)
            .collect();
        let mut true_sorted = true_support.to_vec();
        true_sorted.sort_unstable();
        true_sorted.dedup();
        true_sorted == recovered_support
    }

    /// Normalized recovery error ‖x - x̂‖₂ / ‖x‖₂.
    ///
    /// Falls back to the absolute error when ‖x‖₂ is (near-)zero.
    pub fn normalized_error(original: &[f32], recovered: &[f32]) -> f32 {
        let orig_norm: f32 = original.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let err_norm: f32 = original.iter().zip(recovered.iter()).map(|(a, b)| (a - b).powi(2)).sum::<f32>().sqrt();
        if orig_norm < 1e-10 { return err_norm; }
        err_norm / orig_norm
    }

    /// Sparsity of the recovered signal (fraction of entries with |v| ≤ threshold).
    pub fn recovered_sparsity(recovered: &[f32], threshold: f32) -> f32 {
        if recovered.is_empty() { return 0.0; }
        let zeros = recovered.iter().filter(|&&v| v.abs() <= threshold).count();
        zeros as f32 / recovered.len() as f32
    }
}
1226
1227// ── Helpers for advanced module ────────────────────────────────────────────────
1228
1229/// Initialize random key/query/value matrices for attention tests.
1230pub fn init_attention_matrices(
1231    seq_len: usize,
1232    head_dim: usize,
1233    rng: &mut impl Rng,
1234) -> (Vec<Vec<f32>>, Vec<Vec<f32>>, Vec<Vec<f32>>) {
1235    let scale = (1.0 / head_dim as f32).sqrt();
1236    let q: Vec<Vec<f32>> = (0..seq_len)
1237        .map(|_| (0..head_dim).map(|_| box_muller(rng) * scale).collect())
1238        .collect();
1239    let k: Vec<Vec<f32>> = (0..seq_len)
1240        .map(|_| (0..head_dim).map(|_| box_muller(rng) * scale).collect())
1241        .collect();
1242    let v: Vec<Vec<f32>> = (0..seq_len)
1243        .map(|_| (0..head_dim).map(|_| box_muller(rng) * scale).collect())
1244        .collect();
1245    (q, k, v)
1246}