1use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
5use scirs2_core::RngExt;
6use super::{box_muller, dot, norm, LassoEncoder};
7
8pub struct BigBirdAttention {
17 pub seq_len: usize,
19 pub n_heads: usize,
21 pub head_dim: usize,
23 pub block_size: usize,
25 pub n_random: usize,
27 pub n_global: usize,
29}
30
31impl BigBirdAttention {
32 pub fn compute_attention_mask(&self, rng: &mut impl Rng) -> Vec<Vec<bool>> {
37 let n = self.seq_len;
38 let mut mask = vec![vec![false; n]; n];
39
40 let bs = self.block_size.max(1);
42 for i in 0..n {
43 let block_i = i / bs;
44 let start_block = block_i.saturating_sub(1);
46 let end_block = (block_i + 2).min((n + bs - 1) / bs);
47 for b in start_block..end_block {
48 let start_tok = b * bs;
49 let end_tok = ((b + 1) * bs).min(n);
50 for j in start_tok..end_tok {
51 mask[i][j] = true;
52 }
53 }
54 }
55
56 for i in 0..n {
58 for _ in 0..self.n_random {
59 let j = rng.random_range(0..n);
60 mask[i][j] = true;
61 }
62 }
63
64 for g in 0..self.n_global.min(n) {
66 for j in 0..n {
67 mask[g][j] = true;
68 mask[j][g] = true;
69 }
70 }
71
72 mask
73 }
74
75 pub fn sparse_attention(
80 &self,
81 q: &[Vec<f32>],
82 k: &[Vec<f32>],
83 v: &[Vec<f32>],
84 mask: &[Vec<bool>],
85 ) -> Vec<Vec<f32>> {
86 let n = self.seq_len.min(q.len());
87 let d = self.head_dim as f32;
88 let scale = 1.0 / d.sqrt();
89 let mut output = vec![vec![0.0_f32; self.head_dim]; n];
90
91 for i in 0..n {
92 let allowed: Vec<usize> = (0..n).filter(|&j| mask[i].get(j).copied().unwrap_or(false)).collect();
94 if allowed.is_empty() {
95 continue;
96 }
97
98 let scores: Vec<f32> = allowed
100 .iter()
101 .map(|&j| dot(&q[i], &k[j]) * scale)
102 .collect();
103
104 let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
106 let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
107 let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-9);
108
109 for (pos, &j) in allowed.iter().enumerate() {
111 let weight = exp_scores[pos] / sum_exp;
112 for (out, &vval) in output[i].iter_mut().zip(v[j].iter()) {
113 *out += weight * vval;
114 }
115 }
116 }
117 output
118 }
119
120 pub fn mask_sparsity(mask: &[Vec<bool>]) -> f32 {
122 let total = (mask.len() * mask.first().map(|r| r.len()).unwrap_or(0)) as f32;
123 if total < 1.0 {
124 return 0.0;
125 }
126 let active: usize = mask.iter().flat_map(|row| row.iter()).filter(|&&b| b).count();
127 active as f32 / total
128 }
129}
130
131pub struct SparseSlidingWindowAttention {
138 pub seq_len: usize,
140 pub n_heads: usize,
142 pub head_dim: usize,
144 pub window_size: usize,
146 pub global_tokens: Vec<usize>,
148}
149
150impl SparseSlidingWindowAttention {
151 pub fn compute_attention_mask(&self) -> Vec<Vec<bool>> {
153 let n = self.seq_len;
154 let mut mask = vec![vec![false; n]; n];
155
156 for i in 0..n {
158 let start = i.saturating_sub(self.window_size);
159 let end = (i + self.window_size + 1).min(n);
160 for j in start..end {
161 mask[i][j] = true;
162 }
163 }
164
165 for &g in &self.global_tokens {
167 if g < n {
168 for j in 0..n {
169 mask[g][j] = true;
170 mask[j][g] = true;
171 }
172 }
173 }
174
175 mask
176 }
177
178 pub fn local_attention(
182 &self,
183 q: &[Vec<f32>],
184 k: &[Vec<f32>],
185 v: &[Vec<f32>],
186 ) -> Vec<Vec<f32>> {
187 let n = self.seq_len.min(q.len());
188 let scale = 1.0 / (self.head_dim as f32).sqrt();
189 let global_set: std::collections::HashSet<usize> = self.global_tokens.iter().cloned().collect();
190 let mut output = vec![vec![0.0_f32; self.head_dim]; n];
191
192 for i in 0..n {
193 let mut attend_to: Vec<usize> = Vec::new();
194
195 let start = i.saturating_sub(self.window_size);
197 let end = (i + self.window_size + 1).min(n);
198 for j in start..end {
199 attend_to.push(j);
200 }
201
202 for &g in &self.global_tokens {
204 if g < n && !attend_to.contains(&g) {
205 attend_to.push(g);
206 }
207 }
208
209 if global_set.contains(&i) {
211 attend_to = (0..n).collect();
212 }
213
214 let scores: Vec<f32> = attend_to.iter().map(|&j| dot(&q[i], &k[j]) * scale).collect();
215 let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
216 let exps: Vec<f32> = scores.iter().map(|&s| (s - max_s).exp()).collect();
217 let sum_e = exps.iter().sum::<f32>().max(1e-9);
218
219 for (pos, &j) in attend_to.iter().enumerate() {
220 let w = exps[pos] / sum_e;
221 for (o, &vv) in output[i].iter_mut().zip(v[j].iter()) {
222 *o += w * vv;
223 }
224 }
225 }
226 output
227 }
228}
229
230pub struct SparseAttentionRouter {
237 pub seq_len: usize,
239 pub dim: usize,
241 pub top_k: usize,
243 pub router_w: Vec<Vec<f32>>,
245}
246
247impl SparseAttentionRouter {
248 pub fn new(seq_len: usize, dim: usize, top_k: usize, rng: &mut impl Rng) -> Self {
250 let scale = (2.0 / (dim + dim) as f32).sqrt();
251 let router_w: Vec<Vec<f32>> = (0..dim)
252 .map(|_| (0..dim).map(|_| box_muller(rng) * scale).collect())
253 .collect();
254 Self { seq_len, dim, top_k, router_w }
255 }
256
257 pub fn route(&self, queries: &[Vec<f32>], keys: &[Vec<f32>]) -> Vec<(Vec<usize>, Vec<f32>)> {
261 let n_q = queries.len().min(self.seq_len);
262 let n_k = keys.len();
263 let k = self.top_k.min(n_k);
264
265 queries[..n_q].iter().map(|q| {
266 let q_proj: Vec<f32> = self.router_w.iter().map(|row| dot(row, q)).collect();
268
269 let mut scores: Vec<(usize, f32)> = keys.iter().enumerate().map(|(j, key)| {
271 let score = dot(&q_proj, key) / (self.dim as f32).sqrt();
272 (j, score)
273 }).collect();
274
275 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
277 scores.truncate(k);
278
279 let max_s = scores.iter().map(|(_, s)| *s).fold(f32::NEG_INFINITY, f32::max);
281 let exps: Vec<f32> = scores.iter().map(|(_, s)| (s - max_s).exp()).collect();
282 let sum_e = exps.iter().sum::<f32>().max(1e-9);
283 let weights: Vec<f32> = exps.iter().map(|e| e / sum_e).collect();
284 let indices: Vec<usize> = scores.iter().map(|(i, _)| *i).collect();
285
286 (indices, weights)
287 }).collect()
288 }
289
290 pub fn attend(&self, queries: &[Vec<f32>], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<Vec<f32>> {
292 let routing = self.route(queries, keys);
293 routing.into_iter().map(|(indices, weights)| {
294 let mut out = vec![0.0_f32; self.dim];
295 for (idx, w) in indices.iter().zip(weights.iter()) {
296 if *idx < values.len() {
297 for (o, &v) in out.iter_mut().zip(values[*idx].iter()) {
298 *o += w * v;
299 }
300 }
301 }
302 out
303 }).collect()
304 }
305}
306
307pub struct SparsePositionEncoding {
314 pub max_seq_len: usize,
316 pub dim: usize,
318 pub n_buckets: usize,
320 pub bucket_embeddings: Vec<Vec<f32>>,
322}
323
324impl SparsePositionEncoding {
325 pub fn new(max_seq_len: usize, dim: usize, n_buckets: usize, rng: &mut impl Rng) -> Self {
327 let scale = (1.0 / dim as f32).sqrt();
328 let bucket_embeddings = (0..n_buckets)
329 .map(|_| (0..dim).map(|_| box_muller(rng) * scale).collect())
330 .collect();
331 Self { max_seq_len, dim, n_buckets, bucket_embeddings }
332 }
333
334 pub fn relative_position_bucket(&self, rel_pos: i32) -> usize {
336 let abs_pos = rel_pos.unsigned_abs() as usize;
337 let half = self.n_buckets / 2;
338 if abs_pos < half {
339 abs_pos.min(half.saturating_sub(1))
341 } else {
342 let log_bucket = (abs_pos as f32 / half as f32).ln() / (self.max_seq_len as f32 / half as f32).ln().max(1e-8);
344 let bucket = half + (log_bucket * half as f32) as usize;
345 bucket.min(self.n_buckets - 1)
346 }
347 }
348
349 pub fn position_bias_indices(&self, seq_len: usize) -> Vec<usize> {
352 let n = seq_len.min(self.max_seq_len);
353 let mut indices = vec![0usize; n * n];
354 for i in 0..n {
355 for j in 0..n {
356 let rel = j as i32 - i as i32;
357 indices[i * n + j] = self.relative_position_bucket(rel);
358 }
359 }
360 indices
361 }
362
363 pub fn get_position_biases(&self, seq_len: usize) -> Vec<Vec<f32>> {
365 let indices = self.position_bias_indices(seq_len);
366 indices.iter().map(|&b| self.bucket_embeddings[b].clone()).collect()
367 }
368}
369
370pub struct GroupLasso {
377 pub groups: Vec<Vec<usize>>,
379 pub lambda: f32,
381 pub learning_rate: f32,
383 pub max_iter: usize,
385}
386
387impl GroupLasso {
388 pub fn fit(&self, x_data: &[Vec<f32>], y: &[f32]) -> Vec<f32> {
390 if x_data.is_empty() || y.is_empty() {
391 return Vec::new();
392 }
393 let n_features = x_data[0].len();
394 let n_samples = x_data.len();
395 let mut w = vec![0.0_f32; n_features];
396
397 for _ in 0..self.max_iter {
398 let preds: Vec<f32> = x_data.iter().map(|row| dot(row, &w)).collect();
400 let residuals: Vec<f32> = preds.iter().zip(y.iter()).map(|(&p, &yi)| p - yi).collect();
401 let mut grad = vec![0.0_f32; n_features];
402 for (i, row) in x_data.iter().enumerate() {
403 for (j, &xij) in row.iter().enumerate() {
404 grad[j] += xij * residuals[i] / n_samples as f32;
405 }
406 }
407
408 let mut w_new = w.clone();
410 for j in 0..n_features {
411 w_new[j] -= self.learning_rate * grad[j];
412 }
413
414 for group in &self.groups {
416 let group_norm: f32 = group.iter()
417 .filter(|&&j| j < n_features)
418 .map(|&j| w_new[j] * w_new[j])
419 .sum::<f32>()
420 .sqrt();
421 let threshold = self.lambda * self.learning_rate;
422 if group_norm < threshold {
423 for &j in group.iter().filter(|&&j| j < n_features) {
424 w_new[j] = 0.0;
425 }
426 } else {
427 let scale = 1.0 - threshold / group_norm;
428 for &j in group.iter().filter(|&&j| j < n_features) {
429 w_new[j] *= scale;
430 }
431 }
432 }
433 w = w_new;
434 }
435 w
436 }
437
438 pub fn predict(&self, x_data: &[Vec<f32>], w: &[f32]) -> Vec<f32> {
440 x_data.iter().map(|row| dot(row, w)).collect()
441 }
442
443 pub fn group_sparsity(&self, w: &[f32]) -> f32 {
445 if self.groups.is_empty() {
446 return 0.0;
447 }
448 let zero_groups = self.groups.iter().filter(|group| {
449 group.iter().all(|&j| j >= w.len() || w[j].abs() < 1e-8)
450 }).count();
451 zero_groups as f32 / self.groups.len() as f32
452 }
453}
454
/// N:M structured pruning: keep the `n_keep` largest-magnitude weights in
/// every contiguous group of `group_size` weights, zeroing the rest.
pub struct StructuredPruningMask {
    /// Weights kept per group.
    pub n_keep: usize,
    /// Size of each contiguous weight group.
    pub group_size: usize,
}

impl StructuredPruningMask {
    /// The common 2:4 pattern supported by sparse tensor cores.
    pub fn nm_24() -> Self {
        Self { n_keep: 2, group_size: 4 }
    }

    /// Applies the N:M pattern to `weights`.
    ///
    /// Returns the pruned copy and a keep-mask (`true` = weight survives).
    /// Magnitude ties are broken in favour of the earlier index (the sort
    /// is stable). FIX: a `group_size` of 0 is now treated as 1 instead of
    /// panicking with a division by zero.
    pub fn apply(&self, weights: &[f32]) -> (Vec<f32>, Vec<bool>) {
        let n = weights.len();
        let mut pruned = weights.to_vec();
        let mut mask = vec![false; n];
        // Guard: group_size == 0 would make every division below panic.
        let m = self.group_size.max(1);
        let k = self.n_keep.min(m);

        let n_groups = (n + m - 1) / m;
        for g in 0..n_groups {
            let start = g * m;
            let end = (start + m).min(n);

            // Indices of this group sorted by descending |weight|.
            let mut order: Vec<usize> = (start..end).collect();
            order.sort_by(|&a, &b| {
                weights[b]
                    .abs()
                    .partial_cmp(&weights[a].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // The last group may be shorter than m.
            let keep_count = k.min(end - start);
            for (rank, &idx) in order.iter().enumerate() {
                if rank < keep_count {
                    mask[idx] = true;
                } else {
                    pruned[idx] = 0.0;
                }
            }
        }
        (pruned, mask)
    }

    /// Fraction of entries pruned away (false in the mask).
    pub fn sparsity(mask: &[bool]) -> f32 {
        if mask.is_empty() {
            return 0.0;
        }
        let zeros = mask.iter().filter(|&&b| !b).count();
        zeros as f32 / mask.len() as f32
    }
}
520
/// Criterion used to score convolution channels/filters for pruning.
#[derive(Debug, Clone)]
pub enum ChannelImportanceCriterion {
    /// Sum of absolute weights per filter.
    L1Norm,
    /// First-order Taylor estimate: sum of |weight * gradient| per filter.
    TaylorExpansion,
    /// Filter Pruning via Geometric Median: mean distance to the other filters.
    Fpgm,
}
538
539pub struct ChannelPruner {
541 pub prune_ratio: f32,
543 pub criterion: ChannelImportanceCriterion,
545}
546
547impl ChannelPruner {
548 pub fn compute_importance(
553 &self,
554 filters: &[Vec<f32>],
555 gradients: Option<&[Vec<f32>]>,
556 ) -> Vec<f32> {
557 match &self.criterion {
558 ChannelImportanceCriterion::L1Norm => {
559 filters.iter().map(|f| f.iter().map(|w| w.abs()).sum::<f32>()).collect()
560 }
561 ChannelImportanceCriterion::TaylorExpansion => {
562 let grads = gradients.unwrap_or(filters);
563 filters.iter().zip(grads.iter().chain(std::iter::repeat(&filters[0]))).map(|(f, g)| {
564 f.iter().zip(g.iter()).map(|(w, dw)| (w * dw).abs()).sum::<f32>()
565 }).collect()
566 }
567 ChannelImportanceCriterion::Fpgm => {
568 let n = filters.len();
571 (0..n).map(|i| {
572 let sum_dist: f32 = (0..n).filter(|&j| j != i).map(|j| {
575 filters[i].iter().zip(filters[j].iter()).map(|(a, b)| (a - b).powi(2)).sum::<f32>().sqrt()
576 }).sum();
577 if n <= 1 { norm(&filters[i]) } else { sum_dist / (n - 1) as f32 }
578 }).collect()
579 }
580 }
581 }
582
583 pub fn prune_mask(&self, importance: &[f32]) -> Vec<bool> {
587 let n = importance.len();
588 let n_prune = ((n as f32 * self.prune_ratio).round() as usize).min(n);
589 let mut indexed: Vec<(usize, f32)> = importance.iter().cloned().enumerate().collect();
590 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
591 let mut mask = vec![true; n];
592 for (idx, _) in indexed.iter().take(n_prune) {
593 mask[*idx] = false;
594 }
595 mask
596 }
597
598 pub fn apply_pruning(&self, filters: &[Vec<f32>], mask: &[bool]) -> Vec<Vec<f32>> {
600 filters.iter().zip(mask.iter()).map(|(f, &keep)| {
601 if keep { f.clone() } else { vec![0.0_f32; f.len()] }
602 }).collect()
603 }
604}
605
/// Whole-layer pruning decisions driven by a Fisher-information-style
/// importance score.
pub struct LayerPruner {
    /// Layers scoring at or above this threshold are selected.
    pub importance_threshold: f32,
}

impl LayerPruner {
    /// Mean squared (gradient * weight) over the overlapping prefix of the
    /// two slices — an empirical Fisher importance estimate.
    ///
    /// FIX: returns 0.0 when either slice is empty; previously an empty
    /// `gradients` with non-empty `weights` produced NaN via 0.0 / 0.0.
    pub fn layer_importance(&self, weights: &[f32], gradients: &[f32]) -> f32 {
        let n = weights.len().min(gradients.len());
        if n == 0 {
            return 0.0;
        }
        let fisher: f32 =
            (0..n).map(|i| (gradients[i] * weights[i]).powi(2)).sum::<f32>() / n as f32;
        fisher
    }

    /// Boolean per layer: true when its importance clears the threshold.
    /// NOTE(review): despite the name, `true` marks *high*-importance
    /// layers — confirm callers interpret this as "keep", not "prune".
    pub fn select_layers_to_prune(&self, importances: &[f32]) -> Vec<bool> {
        importances.iter().map(|&imp| imp >= self.importance_threshold).collect()
    }

    /// Layer indices ordered from least to most important (stable on ties).
    pub fn rank_layers(&self, importances: &[f32]) -> Vec<usize> {
        let mut indexed: Vec<(usize, f32)> = importances.iter().cloned().enumerate().collect();
        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        indexed.iter().map(|(i, _)| *i).collect()
    }
}
649
650pub struct SparseCodingLayer {
658 pub input_dim: usize,
660 pub dict_size: usize,
662 pub n_steps: usize,
664 pub lambda: f32,
666 pub dictionary: Vec<Vec<f32>>,
668 pub step_sizes: Vec<f32>,
670 pub we_matrix: Vec<Vec<f32>>,
672}
673
674impl SparseCodingLayer {
675 pub fn new(input_dim: usize, dict_size: usize, n_steps: usize, lambda: f32, rng: &mut impl Rng) -> Self {
677 let scale = (2.0 / (input_dim + dict_size) as f32).sqrt();
678 let dictionary: Vec<Vec<f32>> = (0..dict_size)
679 .map(|_| (0..input_dim).map(|_| box_muller(rng) * scale).collect())
680 .collect();
681
682 let dictionary: Vec<Vec<f32>> = dictionary.into_iter().map(|mut atom| {
684 let n = norm(&atom);
685 if n > 1e-8 { for v in atom.iter_mut() { *v /= n; } }
686 atom
687 }).collect();
688
689 let default_step = 0.1_f32;
691 let step_sizes = vec![default_step; n_steps];
692
693 let we_matrix = (0..dict_size)
695 .map(|i| (0..dict_size).map(|j| if i == j { 1.0 - default_step } else { 0.0 }).collect())
696 .collect();
697
698 Self { input_dim, dict_size, n_steps, lambda, dictionary, step_sizes, we_matrix }
699 }
700
701 pub fn forward(&self, x: &[f32]) -> Vec<f32> {
705 let mut z: Vec<f32> = self.dictionary.iter()
707 .map(|atom| dot(atom, x))
708 .collect();
709
710 let init_thresh = self.lambda * self.step_sizes.first().copied().unwrap_or(0.1);
712 for v in z.iter_mut() {
713 *v = LassoEncoder::soft_threshold(*v, init_thresh);
714 }
715
716 for step in 0..self.n_steps {
718 let step_size = self.step_sizes.get(step).copied().unwrap_or(0.1);
719 let threshold = self.lambda * step_size;
720
721 let we_z: Vec<f32> = self.we_matrix.iter().map(|row| dot(row, &z)).collect();
723 let dtx: Vec<f32> = self.dictionary.iter().map(|atom| dot(atom, x) * step_size).collect();
724
725 let mut z_new: Vec<f32> = (0..self.dict_size).map(|i| we_z[i] + dtx[i]).collect();
726 for v in z_new.iter_mut() {
727 *v = LassoEncoder::soft_threshold(*v, threshold);
728 }
729 z = z_new;
730 }
731 z
732 }
733
734 pub fn reconstruct(&self, z: &[f32]) -> Vec<f32> {
736 let mut x_hat = vec![0.0_f32; self.input_dim];
737 for (atom, &zi) in self.dictionary.iter().zip(z.iter()) {
738 for (xh, &a) in x_hat.iter_mut().zip(atom.iter()) {
739 *xh += zi * a;
740 }
741 }
742 x_hat
743 }
744
745 pub fn reconstruction_loss(&self, x: &[f32]) -> f32 {
747 let z = self.forward(x);
748 let x_hat = self.reconstruct(&z);
749 x.iter().zip(x_hat.iter()).map(|(a, b)| (a - b).powi(2)).sum::<f32>()
750 / self.input_dim as f32
751 }
752}
753
754pub struct ListaNetwork {
761 pub input_dim: usize,
763 pub code_dim: usize,
765 pub n_layers: usize,
767 pub thresholds: Vec<f32>,
769 pub wd: Vec<Vec<f32>>,
771 pub we: Vec<Vec<f32>>,
773}
774
775impl ListaNetwork {
776 pub fn new(input_dim: usize, code_dim: usize, n_layers: usize, lambda: f32, rng: &mut impl Rng) -> Self {
778 let scale_wd = (2.0 / (input_dim + code_dim) as f32).sqrt();
779 let scale_we = (2.0 / (code_dim + code_dim) as f32).sqrt();
780 let wd = (0..code_dim).map(|_| (0..input_dim).map(|_| box_muller(rng) * scale_wd).collect()).collect();
781 let we = (0..code_dim).map(|_| (0..code_dim).map(|_| box_muller(rng) * scale_we).collect()).collect();
782 let thresholds = vec![lambda; n_layers];
783 Self { input_dim, code_dim, n_layers, thresholds, wd, we }
784 }
785
786 pub fn forward(&self, x: &[f32]) -> Vec<f32> {
790 let mut z = vec![0.0_f32; self.code_dim];
791 let init: Vec<f32> = self.wd.iter().map(|row| dot(row, x)).collect();
792 let thresh0 = self.thresholds.first().copied().unwrap_or(0.1);
793 for (zi, &init_i) in z.iter_mut().zip(init.iter()) {
794 *zi = LassoEncoder::soft_threshold(init_i, thresh0);
795 }
796
797 for layer in 1..self.n_layers {
798 let thresh = self.thresholds.get(layer).copied().unwrap_or(0.1);
799 let we_z: Vec<f32> = self.we.iter().map(|row| dot(row, &z)).collect();
800 let wd_x: Vec<f32> = self.wd.iter().map(|row| dot(row, x)).collect();
801 let mut z_new: Vec<f32> = (0..self.code_dim).map(|i| we_z[i] + wd_x[i]).collect();
802 for v in z_new.iter_mut() {
803 *v = LassoEncoder::soft_threshold(*v, thresh);
804 }
805 z = z_new;
806 }
807 z
808 }
809
810 pub fn output_sparsity(&self, x: &[f32]) -> f32 {
812 let z = self.forward(x);
813 let zeros = z.iter().filter(|&&v| v == 0.0).count();
814 zeros as f32 / self.code_dim as f32
815 }
816}
817
818pub struct PredictiveCodingLayer {
828 pub input_dim: usize,
830 pub hidden_dim: usize,
832 pub prediction_w: Vec<Vec<f32>>,
834 pub r_lr: f32,
836 pub n_inference_steps: usize,
838}
839
840impl PredictiveCodingLayer {
841 pub fn new(input_dim: usize, hidden_dim: usize, rng: &mut impl Rng) -> Self {
843 let scale = (1.0 / hidden_dim as f32).sqrt();
844 let prediction_w = (0..input_dim)
845 .map(|_| (0..hidden_dim).map(|_| box_muller(rng) * scale).collect())
846 .collect();
847 Self { input_dim, hidden_dim, prediction_w, r_lr: 0.1, n_inference_steps: 20 }
848 }
849
850 pub fn predict(&self, r: &[f32]) -> Vec<f32> {
852 self.prediction_w.iter().map(|row| dot(row, r)).collect()
853 }
854
855 pub fn prediction_error(&self, actual: &[f32], r: &[f32]) -> Vec<f32> {
857 let pred = self.predict(r);
858 actual.iter().zip(pred.iter()).map(|(a, p)| a - p).collect()
859 }
860
861 pub fn infer(&self, x: &[f32]) -> (Vec<f32>, Vec<f32>) {
866 let mut r = vec![0.0_f32; self.hidden_dim];
867
868 for _ in 0..self.n_inference_steps {
869 let e = self.prediction_error(x, &r);
870 let grad_r: Vec<f32> = (0..self.hidden_dim).map(|j| {
872 self.prediction_w.iter().zip(e.iter()).map(|(row, &ei)| row[j] * ei).sum::<f32>()
873 }).collect();
874 for (ri, &g) in r.iter_mut().zip(grad_r.iter()) {
875 *ri += self.r_lr * g;
876 }
877 }
878 let e_final = self.prediction_error(x, &r);
879 (r, e_final)
880 }
881
882 pub fn free_energy(&self, x: &[f32]) -> f32 {
884 let (_, e) = self.infer(x);
885 e.iter().map(|v| v * v).sum::<f32>() / self.input_dim as f32
886 }
887}
888
/// Sensing-matrix family for compressed-sensing measurements.
#[derive(Debug, Clone)]
pub enum CsMatrixType {
    /// i.i.d. Gaussian entries scaled by 1/sqrt(m).
    Gaussian,
    /// Random +-1/sqrt(m) entries.
    Bernoulli,
    /// Random rows of a Walsh-Hadamard matrix (order rounded up to a power of two).
    SubsampledHadamard,
}
905
906pub struct CompressedSensingMatrix {
908 pub m_rows: usize,
910 pub n_cols: usize,
912 pub matrix_type: CsMatrixType,
914 pub matrix: Vec<Vec<f32>>,
916}
917
918impl CompressedSensingMatrix {
919 pub fn new(m_rows: usize, n_cols: usize, matrix_type: CsMatrixType, rng: &mut impl Rng) -> Self {
921 let matrix = match &matrix_type {
922 CsMatrixType::Gaussian => {
923 let scale = 1.0 / (m_rows as f32).sqrt();
924 (0..m_rows).map(|_| {
925 (0..n_cols).map(|_| box_muller(rng) * scale).collect()
926 }).collect()
927 }
928 CsMatrixType::Bernoulli => {
929 let scale = 1.0 / (m_rows as f32).sqrt();
930 (0..m_rows).map(|_| {
931 (0..n_cols).map(|_| if rng.random::<f32>() > 0.5 { scale } else { -scale }).collect()
932 }).collect()
933 }
934 CsMatrixType::SubsampledHadamard => {
935 let nh = n_cols.next_power_of_two();
937 let mut h = vec![vec![1.0_f32; nh]; nh];
938 let mut step = 1usize;
939 while step < nh {
940 for i in (0..nh).step_by(2 * step) {
941 for j in i..(i + step).min(nh) {
942 let a = h[j][0..nh].to_vec();
943 let b = h[j + step][0..nh].to_vec();
944 for k in 0..nh {
945 h[j][k] = a[k] + b[k];
946 h[j + step][k] = a[k] - b[k];
947 }
948 }
949 }
950 step *= 2;
951 }
952 let scale = 1.0 / (m_rows as f32 * nh as f32).sqrt();
953 let mut row_indices: Vec<usize> = (0..nh).collect();
955 for i in 0..m_rows.min(nh) {
956 let j = i + rng.random_range(0..(nh - i));
957 row_indices.swap(i, j);
958 }
959 row_indices.truncate(m_rows.min(nh));
960 row_indices.iter().map(|&r| {
961 h[r][..n_cols].iter().map(|&v| v * scale).collect()
962 }).collect()
963 }
964 };
965 Self { m_rows, n_cols, matrix_type, matrix }
966 }
967
968 pub fn measure(&self, x: &[f32]) -> Vec<f32> {
970 self.matrix.iter().map(|row| dot(row, &x[..x.len().min(self.n_cols)])).collect()
971 }
972
973 pub fn transpose_apply(&self, v: &[f32]) -> Vec<f32> {
975 let mut result = vec![0.0_f32; self.n_cols];
976 for (row, &vi) in self.matrix.iter().zip(v.iter()) {
977 for (r, &a) in result.iter_mut().zip(row.iter()) {
978 *r += a * vi;
979 }
980 }
981 result
982 }
983
984 pub fn sufficient_measurements(n: usize, s: usize) -> usize {
988 if s == 0 || n == 0 {
989 return 1;
990 }
991 let c = 4.0_f32;
992 (c * s as f32 * (n as f32 / s as f32).ln()).ceil() as usize
993 }
994}
995
/// Basis-pursuit denoising (sparse recovery of x from b = A x + noise)
/// solved approximately via ADMM with an inner conjugate-gradient x-update.
pub struct BasisPursuitDenoise {
    /// Noise level; feeds the l1 weight lambda = 1 / (sigma * m).
    pub sigma: f32,
    /// ADMM penalty parameter.
    pub rho: f32,
    /// Maximum number of ADMM iterations.
    pub max_iter: usize,
    /// Stopping tolerance applied to both primal and dual residual norms.
    pub tol: f32,
}

impl BasisPursuitDenoise {
    /// Runs ADMM on the lasso form of the problem and returns the sparse
    /// iterate z.
    ///
    /// Splitting: x carries the data term, z the l1 term, u is the scaled
    /// dual variable. Returns an empty vector when `a` or `b` is empty.
    pub fn solve(&self, a: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
        let m = a.len();
        let n = if m > 0 { a[0].len() } else { 0 };
        if n == 0 || m == 0 {
            return Vec::new();
        }
        // l1 weight; sigma is clamped away from zero to avoid a blow-up.
        let lambda = 1.0 / (self.sigma.max(1e-6) * m as f32);
        let rho = self.rho;

        let mut x = vec![0.0_f32; n];
        let mut z = vec![0.0_f32; n];
        let mut u = vec![0.0_f32; n];

        // Precompute A^T b — constant across iterations.
        let atb: Vec<f32> = (0..n).map(|j| {
            a.iter().zip(b.iter()).map(|(row, &bi)| row[j] * bi).sum::<f32>()
        }).collect();

        for _ in 0..self.max_iter {
            // x-update: solve (A^T A + rho I) x = A^T b + rho (z - u) by CG,
            // warm-started from the previous x.
            let rhs: Vec<f32> = (0..n).map(|i| atb[i] + rho * (z[i] - u[i])).collect();
            x = bpdn_cg(a, rho, &rhs, &x, 50);

            // z-update: soft-threshold the x + u iterate.
            let z_old = z.clone();
            for i in 0..n {
                z[i] = LassoEncoder::soft_threshold(x[i] + u[i], lambda / rho);
            }
            // u-update: scaled dual ascent on the constraint x = z.
            for i in 0..n {
                u[i] += x[i] - z[i];
            }

            // Standard ADMM stopping test on primal/dual residual norms.
            let primal_res: f32 = (0..n).map(|i| (x[i] - z[i]).powi(2)).sum::<f32>().sqrt();
            let dual_res: f32 = (0..n).map(|i| (rho * (z[i] - z_old[i])).powi(2)).sum::<f32>().sqrt();
            if primal_res < self.tol && dual_res < self.tol {
                break;
            }
        }
        z
    }

    /// ||A x - b||_2, for diagnostics.
    pub fn residual_norm(a: &[Vec<f32>], b: &[f32], x: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(row, &bi)| {
            let ax_i: f32 = dot(row, x);
            (ax_i - bi).powi(2)
        }).sum::<f32>().sqrt()
    }
}
1069
/// Conjugate-gradient solve of the ADMM x-update system
/// (A^T A + rho I) x = b, warm-started from `x0`.
///
/// A^T A is never materialised: each CG step applies A then A^T. Iteration
/// stops on `max_iter`, a tiny residual, or a near-zero curvature
/// denominator (numerical safety guard).
fn bpdn_cg(a: &[Vec<f32>], rho: f32, b: &[f32], x0: &[f32], max_iter: usize) -> Vec<f32> {
    let n = b.len();
    let mut x = x0.to_vec();
    // Initial residual r = b - (A^T A + rho I) x0.
    let ax: Vec<f32> = a.iter().map(|row| dot(row, &x)).collect();
    let atax: Vec<f32> = (0..n).map(|j| a.iter().zip(ax.iter()).map(|(row, &ai)| row[j] * ai).sum::<f32>()).collect();
    let mut r: Vec<f32> = (0..n).map(|i| b[i] - atax[i] - rho * x[i]).collect();
    let mut p = r.clone();
    let mut rsold: f32 = r.iter().map(|&v| v * v).sum();

    for _ in 0..max_iter {
        if rsold < 1e-12 { break; }
        // ap_full = (A^T A + rho I) p, applied implicitly.
        let ap: Vec<f32> = a.iter().map(|row| dot(row, &p)).collect();
        let atap: Vec<f32> = (0..n).map(|j| a.iter().zip(ap.iter()).map(|(row, &ai)| row[j] * ai).sum::<f32>()).collect();
        let ap_full: Vec<f32> = (0..n).map(|i| atap[i] + rho * p[i]).collect();
        let denom: f32 = p.iter().zip(ap_full.iter()).map(|(&pi, &api)| pi * api).sum();
        if denom.abs() < 1e-14 { break; }
        let alpha = rsold / denom;
        // Standard CG updates of the iterate, residual, and search direction.
        for i in 0..n { x[i] += alpha * p[i]; r[i] -= alpha * ap_full[i]; }
        let rsnew: f32 = r.iter().map(|&v| v * v).sum();
        let beta = rsnew / rsold.max(1e-14);
        for i in 0..n { p[i] = r[i] + beta * p[i]; }
        rsold = rsnew;
    }
    x
}
1097
1098pub struct RecoveryGuarantees;
1102
1103impl RecoveryGuarantees {
1104 pub fn rip_constant(
1108 a: &[Vec<f32>],
1109 s: usize,
1110 n_trials: usize,
1111 rng: &mut impl Rng,
1112 ) -> f32 {
1113 let n = if a.is_empty() { 0 } else { a[0].len() };
1114 if n == 0 { return 0.0; }
1115 let mut max_delta = 0.0_f32;
1116
1117 for _ in 0..n_trials {
1118 let mut support: Vec<usize> = (0..n).collect();
1120 for i in 0..s.min(n) {
1121 let j = i + rng.random_range(0..(n - i));
1122 support.swap(i, j);
1123 }
1124 let support = &support[..s.min(n)];
1125
1126 let mut x = vec![0.0_f32; n];
1127 let mut sq_sum = 0.0_f32;
1128 for &i in support {
1129 let v = box_muller(rng);
1130 x[i] = v;
1131 sq_sum += v * v;
1132 }
1133 if sq_sum < 1e-10 { continue; }
1134 let x_norm = sq_sum.sqrt();
1135 for v in x.iter_mut() { *v /= x_norm; }
1136
1137 let ax: Vec<f32> = a.iter().map(|row| dot(row, &x)).collect();
1138 let ax_norm_sq: f32 = ax.iter().map(|&v| v * v).sum();
1139 let delta = (ax_norm_sq - 1.0).abs();
1140 if delta > max_delta { max_delta = delta; }
1141 }
1142 max_delta
1143 }
1144
1145 pub fn phase_transition(sparsity_ratio: f32) -> f32 {
1150 let rho = sparsity_ratio.clamp(0.0, 1.0);
1153 if rho < 1e-6 {
1155 return 0.0;
1156 }
1157 let delta = 2.0 * rho * (1.0 / rho).ln() + rho * (1.0 + (2.0 * std::f32::consts::PI).ln());
1158 delta.min(1.0)
1159 }
1160
1161 pub fn is_rip_sufficient(delta_2s: f32) -> bool {
1165 delta_2s < (2.0_f32.sqrt() - 1.0)
1166 }
1167}
1168
/// Empirical quality metrics for compressed-sensing reconstructions.
pub struct CsMetrics;

impl CsMetrics {
    /// Reconstruction SNR in dB, clamped to +-100 for the degenerate
    /// zero-error / zero-signal cases.
    pub fn recovery_snr(original: &[f32], recovered: &[f32]) -> f32 {
        let signal_power: f32 = original.iter().map(|&v| v * v).sum();
        let error_power: f32 = original
            .iter()
            .zip(recovered.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        if error_power < 1e-14 { return 100.0; }
        if signal_power < 1e-14 { return -100.0; }
        10.0 * (signal_power / error_power).log10()
    }

    /// Fraction of true support indices whose recovered magnitude exceeds
    /// `threshold`. An empty true support counts as fully recovered.
    pub fn support_recovery_rate(true_support: &[usize], recovered: &[f32], threshold: f32) -> f32 {
        if true_support.is_empty() { return 1.0; }
        let hits = true_support
            .iter()
            .filter(|&&i| recovered.get(i).map_or(false, |v| v.abs() > threshold))
            .count();
        hits as f32 / true_support.len() as f32
    }

    /// True when the recovered support matches the true support exactly
    /// (compared as sorted index sets).
    pub fn exact_support_recovery(true_support: &[usize], recovered: &[f32], threshold: f32) -> bool {
        let mut found: Vec<usize> = recovered
            .iter()
            .enumerate()
            .filter_map(|(i, &v)| if v.abs() > threshold { Some(i) } else { None })
            .collect();
        found.sort_unstable();
        let mut expected = true_support.to_vec();
        expected.sort_unstable();
        expected == found
    }

    /// ||x - x_hat|| / ||x||; falls back to the absolute error norm when
    /// the original is (near) zero.
    pub fn normalized_error(original: &[f32], recovered: &[f32]) -> f32 {
        let orig_norm = original.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let err_norm = original
            .iter()
            .zip(recovered.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();
        if orig_norm < 1e-10 { err_norm } else { err_norm / orig_norm }
    }

    /// Fraction of entries at or below `threshold` in magnitude.
    pub fn recovered_sparsity(recovered: &[f32], threshold: f32) -> f32 {
        if recovered.is_empty() { return 0.0; }
        let zeros = recovered.iter().filter(|&&v| v.abs() <= threshold).count();
        zeros as f32 / recovered.len() as f32
    }
}
1226
1227pub fn init_attention_matrices(
1231 seq_len: usize,
1232 head_dim: usize,
1233 rng: &mut impl Rng,
1234) -> (Vec<Vec<f32>>, Vec<Vec<f32>>, Vec<Vec<f32>>) {
1235 let scale = (1.0 / head_dim as f32).sqrt();
1236 let q: Vec<Vec<f32>> = (0..seq_len)
1237 .map(|_| (0..head_dim).map(|_| box_muller(rng) * scale).collect())
1238 .collect();
1239 let k: Vec<Vec<f32>> = (0..seq_len)
1240 .map(|_| (0..head_dim).map(|_| box_muller(rng) * scale).collect())
1241 .collect();
1242 let v: Vec<Vec<f32>> = (0..seq_len)
1243 .map(|_| (0..head_dim).map(|_| box_muller(rng) * scale).collect())
1244 .collect();
1245 (q, k, v)
1246}