1use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
21use scirs2_core::RngExt;
22use tenflowers_core::{Result, TensorError};
23
24mod helpers;
25use helpers::*;
26
/// Result of a rank-R CP (CANDECOMP/PARAFAC) decomposition of a 3-way tensor.
#[derive(Debug, Clone)]
pub struct CpFactors {
    /// Mode-0 factor matrix (`dim_i` rows, one column per rank component).
    pub factor_a: Vec<Vec<f64>>,
    /// Mode-1 factor matrix (`dim_j` rows).
    pub factor_b: Vec<Vec<f64>>,
    /// Mode-2 factor matrix (`dim_k` rows).
    pub factor_c: Vec<Vec<f64>>,
    /// Per-component weights absorbed when factor columns are normalized.
    pub lambdas: Vec<f64>,
    /// Relative Frobenius reconstruction error of the final factors.
    pub approx_error: f64,
    /// Number of ALS sweeps actually run.
    pub iterations: usize,
}
47
/// CP (CANDECOMP/PARAFAC) decomposition of a third-order tensor via
/// alternating least squares (ALS).
#[derive(Debug, Clone)]
pub struct EoCpDecomposition {
    /// Number of rank-one components.
    pub rank: usize,
    /// Maximum ALS sweeps.
    pub max_iters: usize,
    /// Convergence threshold on the change in relative error.
    pub tol: f64,
    /// RNG seed for reproducible factor initialization.
    pub seed: u64,
}
61
62impl EoCpDecomposition {
63 pub fn new(rank: usize, max_iters: usize) -> Self {
65 Self {
66 rank,
67 max_iters,
68 tol: 1e-8,
69 seed: 42,
70 }
71 }
72
73 pub fn with_tol(mut self, tol: f64) -> Self {
75 self.tol = tol;
76 self
77 }
78
    /// ALS-based CP decomposition of a dense row-major `dim_i x dim_j x dim_k`
    /// tensor into `rank` rank-one components.
    ///
    /// Factors are initialized uniformly in [-1, 1) from the stored seed.
    /// Each sweep updates A, B, then C from the corresponding mode unfolding,
    /// and iteration stops when the change in relative Frobenius error drops
    /// below `tol` or `max_iters` is reached. Factor columns are then
    /// normalized, with the magnitudes folded into `lambdas`.
    ///
    /// # Errors
    /// Fails when `tensor.len()` does not equal `dim_i * dim_j * dim_k`.
    pub fn decompose(
        &self,
        tensor: &[f64],
        dim_i: usize,
        dim_j: usize,
        dim_k: usize,
    ) -> Result<CpFactors> {
        let total = dim_i * dim_j * dim_k;
        if tensor.len() != total {
            return Err(TensorError::compute_error_simple(format!(
                "CP decomposition: tensor length {} != {}*{}*{} = {}",
                tensor.len(),
                dim_i,
                dim_j,
                dim_k,
                total,
            )));
        }

        let r = self.rank;
        // Seeded RNG => deterministic initialization for a given config.
        let mut rng = StdRng::seed_from_u64(self.seed);

        // Random uniform [-1, 1) init for the three factor matrices.
        let mut a: Vec<Vec<f64>> = (0..dim_i)
            .map(|_| (0..r).map(|_| rng.random_range(-1.0..1.0)).collect())
            .collect();
        let mut b: Vec<Vec<f64>> = (0..dim_j)
            .map(|_| (0..r).map(|_| rng.random_range(-1.0..1.0)).collect())
            .collect();
        let mut c: Vec<Vec<f64>> = (0..dim_k)
            .map(|_| (0..r).map(|_| rng.random_range(-1.0..1.0)).collect())
            .collect();

        let tensor_norm = frobenius(tensor);
        let mut prev_error = f64::MAX;
        let mut iterations = 0;

        for iter in 0..self.max_iters {
            iterations = iter + 1;

            // Update A from the mode-0 unfolding and Khatri-Rao of C, B.
            // NOTE(review): assumes helpers::khatri_rao row ordering matches
            // the `j * dk + k` column layout of mode_unfold — confirm there.
            let kr_cb = khatri_rao(&c, &b);
            let unfold_0 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 0);
            a = self.update_factor(&unfold_0, &kr_cb, r)?;

            // Update B from the mode-1 unfolding.
            let kr_ca = khatri_rao(&c, &a);
            let unfold_1 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 1);
            b = self.update_factor(&unfold_1, &kr_ca, r)?;

            // Update C from the mode-2 unfolding.
            let kr_ba = khatri_rao(&b, &a);
            let unfold_2 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 2);
            c = self.update_factor(&unfold_2, &kr_ba, r)?;

            // Relative error with unit lambdas (factors still unnormalized).
            let recon = self.reconstruct_flat(&a, &b, &c, &vec![1.0; r], dim_i, dim_j, dim_k);
            let err_vec: Vec<f64> = tensor
                .iter()
                .zip(recon.iter())
                .map(|(t, r)| t - r)
                .collect();
            let error = frobenius(&err_vec) / (tensor_norm + 1e-30);

            if (prev_error - error).abs() < self.tol {
                break;
            }
            prev_error = error;
        }

        // Normalize each column of A, B, C to unit L2 norm; the product of
        // the three column norms becomes the component weight lambda.
        let mut lambdas = vec![1.0_f64; r];
        for col in 0..r {
            let na: f64 = a.iter().map(|row| row[col] * row[col]).sum::<f64>().sqrt();
            let nb: f64 = b.iter().map(|row| row[col] * row[col]).sum::<f64>().sqrt();
            let nc: f64 = c.iter().map(|row| row[col] * row[col]).sum::<f64>().sqrt();
            lambdas[col] = na * nb * nc;
            if na > 1e-15 {
                for row in &mut a {
                    row[col] /= na;
                }
            }
            if nb > 1e-15 {
                for row in &mut b {
                    row[col] /= nb;
                }
            }
            if nc > 1e-15 {
                for row in &mut c {
                    row[col] /= nc;
                }
            }
        }

        // Final error with normalized factors and explicit lambdas.
        let recon = self.reconstruct_flat(&a, &b, &c, &lambdas, dim_i, dim_j, dim_k);
        let err_vec: Vec<f64> = tensor
            .iter()
            .zip(recon.iter())
            .map(|(t, r)| t - r)
            .collect();
        let approx_error = frobenius(&err_vec) / (tensor_norm + 1e-30);

        Ok(CpFactors {
            factor_a: a,
            factor_b: b,
            factor_c: c,
            lambdas,
            approx_error,
            iterations,
        })
    }
193
194 pub fn reconstruct(factors: &CpFactors, dim_i: usize, dim_j: usize, dim_k: usize) -> Vec<f64> {
196 let r = factors.lambdas.len();
197 let mut out = vec![0.0_f64; dim_i * dim_j * dim_k];
198 for comp in 0..r {
199 let lam = factors.lambdas[comp];
200 for i in 0..dim_i {
201 for j in 0..dim_j {
202 for k in 0..dim_k {
203 out[i * dim_j * dim_k + j * dim_k + k] += lam
204 * factors.factor_a[i][comp]
205 * factors.factor_b[j][comp]
206 * factors.factor_c[k][comp];
207 }
208 }
209 }
210 }
211 out
212 }
213
214 fn mode_unfold(
217 &self,
218 tensor: &[f64],
219 di: usize,
220 dj: usize,
221 dk: usize,
222 mode: usize,
223 ) -> Vec<Vec<f64>> {
224 match mode {
225 0 => {
226 (0..di)
228 .map(|i| {
229 let mut row = Vec::with_capacity(dj * dk);
230 for j in 0..dj {
231 for k in 0..dk {
232 row.push(tensor[i * dj * dk + j * dk + k]);
233 }
234 }
235 row
236 })
237 .collect()
238 }
239 1 => {
240 (0..dj)
242 .map(|j| {
243 let mut row = Vec::with_capacity(di * dk);
244 for i in 0..di {
245 for k in 0..dk {
246 row.push(tensor[i * dj * dk + j * dk + k]);
247 }
248 }
249 row
250 })
251 .collect()
252 }
253 _ => {
254 (0..dk)
256 .map(|k| {
257 let mut row = Vec::with_capacity(di * dj);
258 for i in 0..di {
259 for j in 0..dj {
260 row.push(tensor[i * dj * dk + j * dk + k]);
261 }
262 }
263 row
264 })
265 .collect()
266 }
267 }
268 }
269
270 fn update_factor(
271 &self,
272 unfold: &[Vec<f64>],
273 kr: &[Vec<f64>],
274 _rank: usize,
275 ) -> Result<Vec<Vec<f64>>> {
276 let product = mat_mul(unfold, kr);
278 let krtk = mat_mul(&mat_t(kr), kr);
279 let n = krtk.len();
280 let mut aug: Vec<Vec<f64>> = krtk
282 .iter()
283 .enumerate()
284 .map(|(i, row)| {
285 let mut r = row.clone();
286 for j in 0..n {
287 r.push(if i == j { 1.0 } else { 0.0 });
288 }
289 r
290 })
291 .collect();
292 for col in 0..n {
293 let mut max_row = col;
294 let mut max_val = aug[col][col].abs();
295 for row in (col + 1)..n {
296 let v = aug[row][col].abs();
297 if v > max_val {
298 max_val = v;
299 max_row = row;
300 }
301 }
302 if max_val < 1e-14 {
303 aug[col][col] += 1e-10;
305 }
306 aug.swap(col, max_row);
307 let pivot = aug[col][col];
308 for j in 0..(2 * n) {
309 aug[col][j] /= pivot;
310 }
311 for row in 0..n {
312 if row == col {
313 continue;
314 }
315 let factor = aug[row][col];
316 for j in 0..(2 * n) {
317 aug[row][j] -= factor * aug[col][j];
318 }
319 }
320 }
321 let inv: Vec<Vec<f64>> = aug.iter().map(|r| r[n..].to_vec()).collect();
322 Ok(mat_mul(&product, &inv))
323 }
324
325 fn reconstruct_flat(
326 &self,
327 a: &[Vec<f64>],
328 b: &[Vec<f64>],
329 c: &[Vec<f64>],
330 lambdas: &[f64],
331 di: usize,
332 dj: usize,
333 dk: usize,
334 ) -> Vec<f64> {
335 let r = lambdas.len();
336 let mut out = vec![0.0_f64; di * dj * dk];
337 for comp in 0..r {
338 let lam = lambdas[comp];
339 for i in 0..di {
340 for j in 0..dj {
341 for k in 0..dk {
342 out[i * dj * dk + j * dk + k] += lam * a[i][comp] * b[j][comp] * c[k][comp];
343 }
344 }
345 }
346 }
347 out
348 }
349}
350
/// Result of a Tucker decomposition.
#[derive(Debug, Clone)]
pub struct TuckerFactors {
    /// Dense core tensor, row-major with shape `core_shape`.
    pub core: Vec<f64>,
    /// Achieved (possibly truncated) ranks per mode.
    pub core_shape: (usize, usize, usize),
    /// The three factor matrices U1, U2, U3 (rows x achieved rank).
    pub factors: Vec<Vec<Vec<f64>>>,
    /// Dense element count divided by (core + factors) element count.
    pub compression_ratio: f64,
    /// Relative Frobenius reconstruction error.
    pub approx_error: f64,
}
368
/// HOSVD-style Tucker decomposition configuration.
#[derive(Debug, Clone)]
pub struct EoTuckerDecomposition {
    /// Target ranks for modes 0, 1, 2.
    pub ranks: (usize, usize, usize),
    /// Iteration count passed to the truncated SVD helper.
    pub svd_iters: usize,
    /// Base RNG seed (mode n uses `seed + n`).
    pub seed: u64,
}
380
381impl EoTuckerDecomposition {
382 pub fn new(ranks: (usize, usize, usize)) -> Self {
383 Self {
384 ranks,
385 svd_iters: 50,
386 seed: 42,
387 }
388 }
389
    /// Higher-order SVD (HOSVD-style) Tucker decomposition: one truncated SVD
    /// per mode unfolding yields the factor matrices, then the tensor is
    /// projected onto them to form the core.
    ///
    /// Also reports the storage compression ratio (dense size vs. core plus
    /// factors) and the relative Frobenius reconstruction error.
    ///
    /// # Errors
    /// Fails when `tensor.len()` does not equal `dim_i * dim_j * dim_k`.
    pub fn decompose(
        &self,
        tensor: &[f64],
        dim_i: usize,
        dim_j: usize,
        dim_k: usize,
    ) -> Result<TuckerFactors> {
        let total = dim_i * dim_j * dim_k;
        if tensor.len() != total {
            return Err(TensorError::compute_error_simple(format!(
                "Tucker decomposition: tensor length {} != {}",
                tensor.len(),
                total,
            )));
        }

        let (r1, r2, r3) = self.ranks;
        let tensor_norm = frobenius(tensor);

        // One truncated SVD per mode; distinct seeds decorrelate the
        // randomized initializations.
        let unfold_0 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 0);
        let (u1, _s1, _vt1) = truncated_svd(&unfold_0, r1, self.svd_iters, self.seed)?;

        let unfold_1 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 1);
        let (u2, _s2, _vt2) = truncated_svd(&unfold_1, r2, self.svd_iters, self.seed + 1)?;

        let unfold_2 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 2);
        let (u3, _s3, _vt3) = truncated_svd(&unfold_2, r3, self.svd_iters, self.seed + 2)?;

        let core = self.compute_core(tensor, dim_i, dim_j, dim_k, &u1, &u2, &u3);
        // The SVD helper may return fewer columns than requested, so read the
        // achieved ranks back from the factor matrices.
        let actual_r1 = u1.first().map_or(0, |r| r.len());
        let actual_r2 = u2.first().map_or(0, |r| r.len());
        let actual_r3 = u3.first().map_or(0, |r| r.len());

        // Element-count ratio: dense tensor vs. core plus factor matrices.
        let original_size = total;
        let compressed_size = actual_r1 * actual_r2 * actual_r3
            + dim_i * actual_r1
            + dim_j * actual_r2
            + dim_k * actual_r3;
        let compression_ratio = if compressed_size > 0 {
            original_size as f64 / compressed_size as f64
        } else {
            0.0
        };

        let recon = Self::reconstruct_from_parts(
            &core, actual_r1, actual_r2, actual_r3, &u1, &u2, &u3, dim_i, dim_j, dim_k,
        );
        let err_vec: Vec<f64> = tensor
            .iter()
            .zip(recon.iter())
            .map(|(t, r)| t - r)
            .collect();
        let approx_error = frobenius(&err_vec) / (tensor_norm + 1e-30);

        Ok(TuckerFactors {
            core,
            core_shape: (actual_r1, actual_r2, actual_r3),
            factors: vec![u1, u2, u3],
            compression_ratio,
            approx_error,
        })
    }
459
460 pub fn reconstruct(
462 factors: &TuckerFactors,
463 dim_i: usize,
464 dim_j: usize,
465 dim_k: usize,
466 ) -> Vec<f64> {
467 let (r1, r2, r3) = factors.core_shape;
468 Self::reconstruct_from_parts(
469 &factors.core,
470 r1,
471 r2,
472 r3,
473 &factors.factors[0],
474 &factors.factors[1],
475 &factors.factors[2],
476 dim_i,
477 dim_j,
478 dim_k,
479 )
480 }
481
482 fn mode_unfold(
485 &self,
486 tensor: &[f64],
487 di: usize,
488 dj: usize,
489 dk: usize,
490 mode: usize,
491 ) -> Vec<Vec<f64>> {
492 match mode {
493 0 => (0..di)
494 .map(|i| {
495 let mut row = Vec::with_capacity(dj * dk);
496 for j in 0..dj {
497 for k in 0..dk {
498 row.push(tensor[i * dj * dk + j * dk + k]);
499 }
500 }
501 row
502 })
503 .collect(),
504 1 => (0..dj)
505 .map(|j| {
506 let mut row = Vec::with_capacity(di * dk);
507 for i in 0..di {
508 for k in 0..dk {
509 row.push(tensor[i * dj * dk + j * dk + k]);
510 }
511 }
512 row
513 })
514 .collect(),
515 _ => (0..dk)
516 .map(|k| {
517 let mut row = Vec::with_capacity(di * dj);
518 for i in 0..di {
519 for j in 0..dj {
520 row.push(tensor[i * dj * dk + j * dk + k]);
521 }
522 }
523 row
524 })
525 .collect(),
526 }
527 }
528
529 fn compute_core(
530 &self,
531 tensor: &[f64],
532 di: usize,
533 dj: usize,
534 dk: usize,
535 u1: &[Vec<f64>],
536 u2: &[Vec<f64>],
537 u3: &[Vec<f64>],
538 ) -> Vec<f64> {
539 let r1 = u1.first().map_or(0, |r| r.len());
540 let r2 = u2.first().map_or(0, |r| r.len());
541 let r3 = u3.first().map_or(0, |r| r.len());
542 let mut core = vec![0.0_f64; r1 * r2 * r3];
543 for i in 0..di {
545 for j in 0..dj {
546 let tij_base = i * dj * dk + j * dk;
547 for k in 0..dk {
548 let val = tensor[tij_base + k];
549 if val.abs() < 1e-30 {
550 continue;
551 }
552 for a in 0..r1 {
553 let u1_ia = u1[i][a];
554 if u1_ia.abs() < 1e-30 {
555 continue;
556 }
557 for b in 0..r2 {
558 let u2_jb = u2[j][b];
559 if u2_jb.abs() < 1e-30 {
560 continue;
561 }
562 for c in 0..r3 {
563 core[a * r2 * r3 + b * r3 + c] += val * u1_ia * u2_jb * u3[k][c];
564 }
565 }
566 }
567 }
568 }
569 }
570 core
571 }
572
573 fn reconstruct_from_parts(
574 core: &[f64],
575 r1: usize,
576 r2: usize,
577 r3: usize,
578 u1: &[Vec<f64>],
579 u2: &[Vec<f64>],
580 u3: &[Vec<f64>],
581 di: usize,
582 dj: usize,
583 dk: usize,
584 ) -> Vec<f64> {
585 let mut out = vec![0.0_f64; di * dj * dk];
586 for a in 0..r1 {
588 for b in 0..r2 {
589 for c in 0..r3 {
590 let g_abc = core[a * r2 * r3 + b * r3 + c];
591 if g_abc.abs() < 1e-30 {
592 continue;
593 }
594 for i in 0..di {
595 let ga_u1 = g_abc * u1[i][a];
596 if ga_u1.abs() < 1e-30 {
597 continue;
598 }
599 for j in 0..dj {
600 let gau2 = ga_u1 * u2[j][b];
601 if gau2.abs() < 1e-30 {
602 continue;
603 }
604 for k in 0..dk {
605 out[i * dj * dk + j * dk + k] += gau2 * u3[k][c];
606 }
607 }
608 }
609 }
610 }
611 }
612 out
613 }
614}
615
/// A single third-order tensor-train core with shape `(r_left, n, r_right)`,
/// stored row-major in `data`.
#[derive(Debug, Clone)]
pub struct TtCore {
    pub data: Vec<f64>,
    pub shape: (usize, usize, usize),
}
626
/// A tensor-train decomposition: one third-order core per original mode.
#[derive(Debug, Clone)]
pub struct TtFactors {
    /// Cores in left-to-right contraction order.
    pub cores: Vec<TtCore>,
    /// Shape of the tensor that was decomposed.
    pub original_shape: Vec<usize>,
    /// Relative Frobenius reconstruction error.
    pub approx_error: f64,
}
636
/// TT-SVD tensor-train decomposition with a global rank cap.
#[derive(Debug, Clone)]
pub struct TtDecomposition {
    /// Upper bound on every internal TT rank.
    pub max_rank: usize,
    /// Convergence tolerance.
    /// NOTE(review): currently not consulted by `decompose` — confirm intent.
    pub tol: f64,
    /// Iteration count for the truncated SVD helper.
    pub svd_iters: usize,
    /// Base RNG seed (split k uses `seed + k`).
    pub seed: u64,
}
647
648impl TtDecomposition {
649 pub fn new(max_rank: usize) -> Self {
650 Self {
651 max_rank,
652 tol: 1e-8,
653 svd_iters: 50,
654 seed: 42,
655 }
656 }
657
    /// TT-SVD: sequentially reshapes the carried tensor into a
    /// `(r_prev * n_k) x remaining` matrix, takes a truncated SVD capped at
    /// `max_rank`, stores U as the k-th core, and carries `S * Vᵀ` forward.
    /// The residual after the final split becomes the last core.
    ///
    /// # Errors
    /// Fails when `tensor.len()` does not match the product of `shape`, or
    /// when `shape` has fewer than 2 dimensions.
    pub fn decompose(&self, tensor: &[f64], shape: &[usize]) -> Result<TtFactors> {
        let total: usize = shape.iter().product();
        if tensor.len() != total {
            return Err(TensorError::compute_error_simple(format!(
                "TT decomposition: tensor length {} != product of shape {:?} = {}",
                tensor.len(),
                shape,
                total,
            )));
        }
        if shape.len() < 2 {
            return Err(TensorError::compute_error_simple(
                "TT decomposition requires at least 2 dimensions".to_string(),
            ));
        }

        let tensor_norm = frobenius(tensor);
        let d = shape.len();
        let mut cores: Vec<TtCore> = Vec::with_capacity(d);
        // Working copy of the carried tensor; shrinks after every split.
        let mut c = tensor.to_vec();
        let mut r_prev = 1_usize;

        let mut remaining: usize = total;

        for k in 0..(d - 1) {
            let n_k = shape[k];
            remaining /= n_k;
            // View the carried tensor as a (r_prev * n_k) x remaining matrix.
            let rows = r_prev * n_k;
            let cols = remaining;
            let mat: Vec<Vec<f64>> = (0..rows)
                .map(|i| (0..cols).map(|j| c[i * cols + j]).collect())
                .collect();

            let rank = self.max_rank.min(rows).min(cols);
            let (u, s, vt) = truncated_svd(&mat, rank, self.svd_iters, self.seed + k as u64)?;

            // The helper may return fewer singular values than requested.
            let actual_rank = s.len();
            let mut core_data = vec![0.0_f64; r_prev * n_k * actual_rank];
            for i in 0..rows {
                for r in 0..actual_rank {
                    core_data[i * actual_rank + r] =
                        u.get(i).and_then(|row| row.get(r).copied()).unwrap_or(0.0);
                }
            }
            cores.push(TtCore {
                data: core_data,
                shape: (r_prev, n_k, actual_rank),
            });

            // Carry S * Vᵀ forward as the next working tensor.
            let mut new_c = vec![0.0_f64; actual_rank * cols];
            for r in 0..actual_rank {
                for j in 0..cols {
                    new_c[r * cols + j] =
                        s[r] * vt.get(r).and_then(|row| row.get(j).copied()).unwrap_or(0.0);
                }
            }
            c = new_c;
            r_prev = actual_rank;
        }

        // Whatever is left becomes the final core, with trailing rank 1.
        let n_last = shape[d - 1];
        let mut last_core_data = vec![0.0_f64; r_prev * n_last];
        let copy_len = c.len().min(r_prev * n_last);
        last_core_data[..copy_len].copy_from_slice(&c[..copy_len]);
        cores.push(TtCore {
            data: last_core_data,
            shape: (r_prev, n_last, 1),
        });

        let recon = Self::reconstruct_flat(&cores, shape);
        let err_vec: Vec<f64> = tensor
            .iter()
            .zip(recon.iter())
            .map(|(t, r)| t - r)
            .collect();
        let approx_error = frobenius(&err_vec) / (tensor_norm + 1e-30);

        Ok(TtFactors {
            cores,
            original_shape: shape.to_vec(),
            approx_error,
        })
    }
749
750 pub fn reconstruct(factors: &TtFactors) -> Vec<f64> {
752 Self::reconstruct_flat(&factors.cores, &factors.original_shape)
753 }
754
755 pub fn round(&self, factors: &TtFactors, new_max_rank: usize) -> Result<TtFactors> {
757 let full = Self::reconstruct_flat(&factors.cores, &factors.original_shape);
758 let mut dec = self.clone();
759 dec.max_rank = new_max_rank;
760 dec.decompose(&full, &factors.original_shape)
761 }
762
763 fn reconstruct_flat(cores: &[TtCore], shape: &[usize]) -> Vec<f64> {
764 let total: usize = shape.iter().product();
765 let d = shape.len();
766 let mut result = vec![0.0_f64; total];
767
768 let mut indices = vec![0_usize; d];
770 for flat_idx in 0..total {
771 let mut rem = flat_idx;
773 for k in (0..d).rev() {
774 indices[k] = rem % shape[k];
775 rem /= shape[k];
776 }
777
778 let mut vec_cur: Vec<f64> = vec![1.0];
781 let mut cur_cols = 1_usize;
782
783 for (k, core) in cores.iter().enumerate() {
784 let (r_left, _n_k, r_right) = core.shape;
785 let ik = indices[k];
786 let mut new_vec = vec![0.0_f64; r_right];
790 debug_assert_eq!(cur_cols, r_left);
792 for a in 0..r_left {
793 let v_a = vec_cur[a];
794 if v_a.abs() < 1e-30 {
795 continue;
796 }
797 for b in 0..r_right {
798 let core_val = core.data[a * cores[k].shape.1 * r_right + ik * r_right + b];
799 new_vec[b] += v_a * core_val;
800 }
801 }
802 vec_cur = new_vec;
803 cur_cols = r_right;
804 }
805
806 result[flat_idx] = vec_cur[0];
807 }
808 result
809 }
810}
811
/// Output of scalar codebook (k-means) quantization.
#[derive(Debug, Clone)]
pub struct CodebookResult {
    /// Centroid values, one per cluster.
    pub codebook: Vec<f64>,
    /// For each input weight, the index of its assigned centroid.
    pub indices: Vec<usize>,
    /// Bit-level compression ratio (original bits / compressed bits).
    pub compression_ratio: f64,
    /// Mean squared quantization error.
    pub mse: f64,
}
828
/// Scalar k-means quantizer configuration.
#[derive(Debug, Clone)]
pub struct CodebookQuantization {
    /// Number of codebook entries (clamped to the sample count at runtime).
    pub n_clusters: usize,
    /// Maximum Lloyd iterations.
    pub max_iters: usize,
    /// RNG seed for centroid initialization.
    pub seed: u64,
}
839
840impl CodebookQuantization {
841 pub fn new(n_clusters: usize) -> Self {
842 Self {
843 n_clusters,
844 max_iters: 100,
845 seed: 42,
846 }
847 }
848
    /// Scalar (1-D) k-means codebook quantization of `weights`.
    ///
    /// Centroids are seeded with a k-means++-style probabilistic init, then
    /// refined by Lloyd iterations until assignments stabilize or `max_iters`
    /// is hit. Returns the codebook, per-weight indices, the MSE, and the
    /// bit-level compression ratio (64-bit floats vs. codebook + packed
    /// indices).
    ///
    /// # Errors
    /// Fails on an empty weight vector or a zero cluster count.
    pub fn quantize(&self, weights: &[f64]) -> Result<CodebookResult> {
        if weights.is_empty() {
            return Err(TensorError::compute_error_simple(
                "CodebookQuantization: empty weight vector".to_string(),
            ));
        }
        // Clamp the cluster count to the number of samples.
        let k = self.n_clusters.min(weights.len());
        if k == 0 {
            return Err(TensorError::compute_error_simple(
                "CodebookQuantization: n_clusters must be > 0".to_string(),
            ));
        }

        let mut rng = StdRng::seed_from_u64(self.seed);

        // k-means++-style init: first centroid chosen uniformly at random.
        let mut centroids: Vec<f64> = Vec::with_capacity(k);
        let first_idx = rng.random_range(0..weights.len());
        centroids.push(weights[first_idx]);

        for _ in 1..k {
            // Squared distance from each weight to its nearest centroid.
            let mut dist_sq: Vec<f64> = weights
                .iter()
                .map(|w| {
                    centroids
                        .iter()
                        .map(|c| (w - c) * (w - c))
                        .fold(f64::MAX, f64::min)
                })
                .collect();
            let total: f64 = dist_sq.iter().sum();
            if total < 1e-30 {
                // Degenerate case: every weight coincides with a centroid;
                // fall back to uniform sampling.
                centroids.push(weights[rng.random_range(0..weights.len())]);
                continue;
            }
            // Sample the next centroid proportionally to squared distance.
            for d in &mut dist_sq {
                *d /= total;
            }
            let r: f64 = rng.random_range(0.0..1.0);
            let mut cumsum = 0.0;
            let mut chosen = 0;
            for (i, &d) in dist_sq.iter().enumerate() {
                cumsum += d;
                if cumsum >= r {
                    chosen = i;
                    break;
                }
            }
            centroids.push(weights[chosen]);
        }

        // Lloyd iterations: assign, then recompute centroid means.
        let mut indices = vec![0_usize; weights.len()];
        for _iter in 0..self.max_iters {
            // Assignment step: nearest centroid by squared distance.
            let mut changed = false;
            for (i, w) in weights.iter().enumerate() {
                let mut best = 0;
                let mut best_dist = f64::MAX;
                for (c, centroid) in centroids.iter().enumerate() {
                    let d = (w - centroid) * (w - centroid);
                    if d < best_dist {
                        best_dist = d;
                        best = c;
                    }
                }
                if indices[i] != best {
                    changed = true;
                    indices[i] = best;
                }
            }
            if !changed {
                break;
            }

            // Update step: move each non-empty cluster to its mean.
            let mut sums = vec![0.0_f64; k];
            let mut counts = vec![0_usize; k];
            for (i, w) in weights.iter().enumerate() {
                sums[indices[i]] += w;
                counts[indices[i]] += 1;
            }
            for c in 0..k {
                if counts[c] > 0 {
                    centroids[c] = sums[c] / counts[c] as f64;
                }
            }
        }

        // Mean squared quantization error over all weights.
        let mse: f64 = weights
            .iter()
            .zip(indices.iter())
            .map(|(w, &idx)| {
                let d = w - centroids[idx];
                d * d
            })
            .sum::<f64>()
            / weights.len() as f64;

        // ceil(log2 k) bits per index (at least 1); 64-bit floats assumed
        // for both the original weights and the codebook entries.
        let index_bits = (k as f64).log2().ceil().max(1.0);
        let original_bits = weights.len() as f64 * 64.0;
        let compressed_bits = k as f64 * 64.0 + weights.len() as f64 * index_bits;
        let compression_ratio = if compressed_bits > 0.0 {
            original_bits / compressed_bits
        } else {
            0.0
        };

        Ok(CodebookResult {
            codebook: centroids,
            indices,
            compression_ratio,
            mse,
        })
    }
970
971 pub fn dequantize(codebook: &[f64], indices: &[usize]) -> Vec<f64> {
973 indices
974 .iter()
975 .map(|&idx| {
976 if idx < codebook.len() {
977 codebook[idx]
978 } else {
979 0.0
980 }
981 })
982 .collect()
983 }
984}
985
/// Encoded output of product quantization.
#[derive(Debug, Clone)]
pub struct PqCodes {
    /// Per vector, one centroid index per subquantizer.
    pub codes: Vec<Vec<usize>>,
    /// Per subquantizer, the list of centroid sub-vectors.
    pub codebooks: Vec<Vec<Vec<f64>>>,
    /// Number of subspaces the vectors were split into.
    pub n_subquantizers: usize,
    /// Dimensionality of each subspace.
    pub sub_dim: usize,
}
1000
/// Product quantizer: independent k-means per contiguous subspace.
#[derive(Debug, Clone)]
pub struct ProductQuantization {
    /// Number of subspaces (must evenly divide the vector dimension).
    pub n_subquantizers: usize,
    /// Codebook size per subspace.
    pub n_centroids: usize,
    /// Maximum Lloyd iterations per subspace.
    pub max_iters: usize,
    /// RNG seed for centroid initialization.
    pub seed: u64,
}
1012
1013impl ProductQuantization {
1014 pub fn new(n_subquantizers: usize, n_centroids: usize) -> Self {
1015 Self {
1016 n_subquantizers,
1017 n_centroids,
1018 max_iters: 50,
1019 seed: 42,
1020 }
1021 }
1022
1023 pub fn encode(&self, vectors: &[f64], n_vectors: usize, dim: usize) -> Result<PqCodes> {
1026 if vectors.len() != n_vectors * dim {
1027 return Err(TensorError::compute_error_simple(format!(
1028 "PQ encode: expected {} elements, got {}",
1029 n_vectors * dim,
1030 vectors.len(),
1031 )));
1032 }
1033 if dim % self.n_subquantizers != 0 {
1034 return Err(TensorError::compute_error_simple(format!(
1035 "PQ encode: dim {} not divisible by n_subquantizers {}",
1036 dim, self.n_subquantizers,
1037 )));
1038 }
1039 let sub_dim = dim / self.n_subquantizers;
1040 let m_count = self.n_subquantizers;
1041 let k = self.n_centroids.min(n_vectors);
1042
1043 let mut codebooks: Vec<Vec<Vec<f64>>> = Vec::with_capacity(m_count);
1044 let mut codes: Vec<Vec<usize>> = vec![vec![0_usize; m_count]; n_vectors];
1045 let mut rng = StdRng::seed_from_u64(self.seed);
1046
1047 for m in 0..m_count {
1048 let offset = m * sub_dim;
1049 let sub_vecs: Vec<Vec<f64>> = (0..n_vectors)
1051 .map(|i| {
1052 (0..sub_dim)
1053 .map(|d| vectors[i * dim + offset + d])
1054 .collect()
1055 })
1056 .collect();
1057
1058 let mut centroids: Vec<Vec<f64>> = (0..k)
1060 .map(|_| {
1061 let idx = rng.random_range(0..n_vectors);
1062 sub_vecs[idx].clone()
1063 })
1064 .collect();
1065
1066 let mut assignments = vec![0_usize; n_vectors];
1067 for _iter in 0..self.max_iters {
1068 let mut changed = false;
1070 for i in 0..n_vectors {
1071 let mut best = 0;
1072 let mut best_dist = f64::MAX;
1073 for (c, centroid) in centroids.iter().enumerate() {
1074 let d: f64 = sub_vecs[i]
1075 .iter()
1076 .zip(centroid.iter())
1077 .map(|(a, b)| (a - b) * (a - b))
1078 .sum();
1079 if d < best_dist {
1080 best_dist = d;
1081 best = c;
1082 }
1083 }
1084 if assignments[i] != best {
1085 changed = true;
1086 assignments[i] = best;
1087 }
1088 }
1089 if !changed {
1090 break;
1091 }
1092 let mut sums = vec![vec![0.0_f64; sub_dim]; k];
1094 let mut counts = vec![0_usize; k];
1095 for (i, &a) in assignments.iter().enumerate() {
1096 for d in 0..sub_dim {
1097 sums[a][d] += sub_vecs[i][d];
1098 }
1099 counts[a] += 1;
1100 }
1101 for c in 0..k {
1102 if counts[c] > 0 {
1103 for d in 0..sub_dim {
1104 centroids[c][d] = sums[c][d] / counts[c] as f64;
1105 }
1106 }
1107 }
1108 }
1109
1110 codebooks.push(centroids);
1111 for (i, &a) in assignments.iter().enumerate() {
1112 codes[i][m] = a;
1113 }
1114 }
1115
1116 Ok(PqCodes {
1117 codes,
1118 codebooks,
1119 n_subquantizers: m_count,
1120 sub_dim,
1121 })
1122 }
1123
1124 pub fn search_adc(&self, query: &[f64], pq_codes: &PqCodes) -> Result<Vec<f64>> {
1127 let dim = pq_codes.n_subquantizers * pq_codes.sub_dim;
1128 if query.len() != dim {
1129 return Err(TensorError::compute_error_simple(format!(
1130 "PQ search: query dim {} != expected {}",
1131 query.len(),
1132 dim,
1133 )));
1134 }
1135
1136 let m_count = pq_codes.n_subquantizers;
1137 let sub_dim = pq_codes.sub_dim;
1138
1139 let mut dist_table: Vec<Vec<f64>> = Vec::with_capacity(m_count);
1141 for m in 0..m_count {
1142 let offset = m * sub_dim;
1143 let q_sub: Vec<f64> = (0..sub_dim).map(|d| query[offset + d]).collect();
1144 let table: Vec<f64> = pq_codes.codebooks[m]
1145 .iter()
1146 .map(|c| {
1147 q_sub
1148 .iter()
1149 .zip(c.iter())
1150 .map(|(a, b)| (a - b) * (a - b))
1151 .sum()
1152 })
1153 .collect();
1154 dist_table.push(table);
1155 }
1156
1157 let n_vectors = pq_codes.codes.len();
1159 let distances: Vec<f64> = (0..n_vectors)
1160 .map(|i| {
1161 (0..m_count)
1162 .map(|m| {
1163 let idx = pq_codes.codes[i][m];
1164 if idx < dist_table[m].len() {
1165 dist_table[m][idx]
1166 } else {
1167 0.0
1168 }
1169 })
1170 .sum()
1171 })
1172 .collect();
1173
1174 Ok(distances)
1175 }
1176}
1177
/// Resource envelope of a deployment target used to constrain search.
#[derive(Debug, Clone)]
pub struct HardwareProfile {
    pub name: String,
    /// Peak model memory available, in bytes.
    pub memory_budget_bytes: usize,
    /// Sustained compute throughput, in MFLOP/s.
    pub compute_budget_mflops: f64,
    /// Target end-to-end inference latency, in milliseconds.
    pub latency_target_ms: f64,
    /// Integer bit-widths the target can execute natively.
    pub supported_int_bits: Vec<usize>,
}
1196
impl HardwareProfile {
    /// Builds a profile; integer support defaults to 8/16/32-bit.
    pub fn new(name: &str, memory_bytes: usize, mflops: f64, latency_ms: f64) -> Self {
        Self {
            name: name.to_string(),
            memory_budget_bytes: memory_bytes,
            compute_budget_mflops: mflops,
            latency_target_ms: latency_ms,
            supported_int_bits: vec![8, 16, 32],
        }
    }

    /// Preset: Cortex-M4-class MCU (256 KiB RAM, 100 MFLOP/s, 50 ms target).
    pub fn cortex_m4() -> Self {
        Self::new("Cortex-M4", 256 * 1024, 100.0, 50.0)
    }

    /// Preset: mid-range mobile SoC (2 GiB, 50 GFLOP/s, 20 ms target).
    pub fn mobile_midrange() -> Self {
        Self::new("Mobile-MidRange", 2 * 1024 * 1024 * 1024, 50_000.0, 20.0)
    }

    /// Preset: Jetson Nano (4 GiB, 472 GFLOP/s, 10 ms target).
    /// NOTE(review): 4 GiB overflows `usize` on 32-bit targets — fine for
    /// 64-bit hosts, confirm supported platforms.
    pub fn jetson_nano() -> Self {
        Self::new("Jetson-Nano", 4 * 1024 * 1024 * 1024, 472_000.0, 10.0)
    }
}
1223
/// One architecture configuration with its estimated costs.
#[derive(Debug, Clone)]
pub struct EoArchCandidate {
    /// Channel-width multiplier applied to the base network.
    pub width_mult: f64,
    /// Depth multiplier applied to the base network.
    pub depth_mult: f64,
    /// Predicted latency in milliseconds on the target profile.
    pub estimated_latency_ms: f64,
    /// Predicted model memory in bytes (4 bytes per parameter).
    pub estimated_memory_bytes: usize,
    /// Predicted MFLOPs per inference.
    pub estimated_mflops: f64,
    /// Heuristic stand-in for accuracy, clamped to [0, 1].
    pub accuracy_proxy: f64,
}
1240
/// Grid search over width/depth multipliers constrained by a hardware profile.
#[derive(Debug, Clone)]
pub struct HardwareAwareSearch {
    /// Deployment target whose budgets gate feasibility.
    pub profile: HardwareProfile,
    /// MFLOPs of the unscaled base network.
    pub base_mflops: f64,
    /// Parameter count of the unscaled base network.
    pub base_params: usize,
    /// Width multipliers to enumerate.
    pub width_mults: Vec<f64>,
    /// Depth multipliers to enumerate.
    pub depth_mults: Vec<f64>,
    /// Seed for the accuracy-proxy noise.
    pub seed: u64,
}
1259
1260impl HardwareAwareSearch {
1261 pub fn new(profile: HardwareProfile, base_mflops: f64, base_params: usize) -> Self {
1262 Self {
1263 profile,
1264 base_mflops,
1265 base_params,
1266 width_mults: vec![0.25, 0.5, 0.75, 1.0],
1267 depth_mults: vec![0.5, 0.75, 1.0],
1268 seed: 42,
1269 }
1270 }
1271
    /// Enumerates the width x depth multiplier grid and estimates cost and
    /// accuracy for each configuration.
    ///
    /// FLOPs scale as `width² * depth`; memory as `params * width * depth`
    /// at 4 bytes per parameter. The accuracy proxy is a saturating capacity
    /// heuristic plus seeded noise, clamped to [0, 1].
    pub fn generate_candidates(&self) -> Vec<EoArchCandidate> {
        let mut rng = StdRng::seed_from_u64(self.seed);
        let mut candidates = Vec::new();

        for &wm in &self.width_mults {
            for &dm in &self.depth_mults {
                let mflops = self.base_mflops * wm * wm * dm;
                // NOTE(review): params scale linearly in width here while
                // FLOPs scale quadratically — confirm the asymmetry is
                // intended.
                let mem = (self.base_params as f64 * wm * dm * 4.0) as usize;
                // Latency in ms assuming full utilization of the budget.
                let latency = if self.profile.compute_budget_mflops > 0.0 {
                    mflops / self.profile.compute_budget_mflops * 1000.0
                } else {
                    f64::MAX
                };
                let accuracy_proxy =
                    (0.5 + 0.4 * (wm * dm).powf(0.3) + rng.random_range(-0.02..0.02))
                        .clamp(0.0, 1.0);

                candidates.push(EoArchCandidate {
                    width_mult: wm,
                    depth_mult: dm,
                    estimated_latency_ms: latency,
                    estimated_memory_bytes: mem,
                    estimated_mflops: mflops,
                    accuracy_proxy,
                });
            }
        }
        candidates
    }
1306
1307 pub fn filter_feasible(&self, candidates: &[EoArchCandidate]) -> Vec<EoArchCandidate> {
1309 candidates
1310 .iter()
1311 .filter(|c| {
1312 c.estimated_latency_ms <= self.profile.latency_target_ms
1313 && c.estimated_memory_bytes <= self.profile.memory_budget_bytes
1314 })
1315 .cloned()
1316 .collect()
1317 }
1318
1319 pub fn pareto_frontier(&self, candidates: &[EoArchCandidate]) -> Vec<usize> {
1322 let n = candidates.len();
1323 let mut is_dominated = vec![false; n];
1324
1325 for i in 0..n {
1326 if is_dominated[i] {
1327 continue;
1328 }
1329 for j in 0..n {
1330 if i == j || is_dominated[j] {
1331 continue;
1332 }
1333 let j_better_acc = candidates[j].accuracy_proxy >= candidates[i].accuracy_proxy;
1335 let j_better_lat =
1336 candidates[j].estimated_latency_ms <= candidates[i].estimated_latency_ms;
1337 let j_strictly_better = candidates[j].accuracy_proxy > candidates[i].accuracy_proxy
1338 || candidates[j].estimated_latency_ms < candidates[i].estimated_latency_ms;
1339 if j_better_acc && j_better_lat && j_strictly_better {
1340 is_dominated[i] = true;
1341 break;
1342 }
1343 }
1344 }
1345
1346 (0..n).filter(|&i| !is_dominated[i]).collect()
1347 }
1348
1349 pub fn search(&self) -> Vec<EoArchCandidate> {
1351 let all = self.generate_candidates();
1352 let feasible = self.filter_feasible(&all);
1353 let pareto_idxs = self.pareto_frontier(&feasible);
1354 pareto_idxs.iter().map(|&i| feasible[i].clone()).collect()
1355 }
1356}
1357
/// A linear layer whose active input/output channels can be narrowed at
/// inference time (slimmable-network building block).
#[derive(Debug, Clone)]
pub struct EoSlimmableLinear {
    /// Row-major weights over the FULL `out_features x in_features` extent;
    /// slimmed passes read a top-left sub-block.
    pub weight: Vec<f64>,
    /// One bias per (full-width) output channel.
    pub bias: Vec<f64>,
    pub in_features: usize,
    pub out_features: usize,
}
1372
1373impl EoSlimmableLinear {
    /// Creates a layer with Xavier/Glorot-style uniform initialization
    /// (`limit = sqrt(6 / (fan_in + fan_out))`) and zero biases, seeded for
    /// reproducibility. Weights are stored row-major: `weight[o * in + i]`.
    pub fn new(in_features: usize, out_features: usize, seed: u64) -> Self {
        let mut rng = StdRng::seed_from_u64(seed);
        let limit = (6.0 / (in_features + out_features) as f64).sqrt();
        let weight: Vec<f64> = (0..out_features * in_features)
            .map(|_| rng.random_range(-limit..limit))
            .collect();
        let bias = vec![0.0_f64; out_features];
        Self {
            weight,
            bias,
            in_features,
            out_features,
        }
    }
1389
1390 pub fn forward_at_width(&self, input: &[f64], width_mult: f64) -> Result<Vec<f64>> {
1393 let active_in = ((self.in_features as f64 * width_mult).ceil() as usize)
1394 .max(1)
1395 .min(self.in_features);
1396 let active_out = ((self.out_features as f64 * width_mult).ceil() as usize)
1397 .max(1)
1398 .min(self.out_features);
1399
1400 if input.len() < active_in {
1401 return Err(TensorError::compute_error_simple(format!(
1402 "SlimmableLinear: input len {} < active_in {}",
1403 input.len(),
1404 active_in,
1405 )));
1406 }
1407
1408 let mut output = Vec::with_capacity(active_out);
1409 for o in 0..active_out {
1410 let mut val = self.bias[o];
1411 for i in 0..active_in {
1412 val += self.weight[o * self.in_features + i] * input[i];
1413 }
1414 output.push(val);
1415 }
1416 Ok(output)
1417 }
1418}
1419
/// A stack of slimmable linear layers with ReLU between hidden layers,
/// runnable at several width multipliers.
#[derive(Debug, Clone)]
pub struct DynamicWidthNetwork {
    pub layers: Vec<EoSlimmableLinear>,
    /// Width multipliers this network is intended to be run at.
    pub width_options: Vec<f64>,
}
1431
1432impl DynamicWidthNetwork {
1433 pub fn new(layer_sizes: &[usize], seed: u64) -> Result<Self> {
1435 if layer_sizes.len() < 2 {
1436 return Err(TensorError::compute_error_simple(
1437 "DynamicWidthNetwork: need at least 2 layer sizes".to_string(),
1438 ));
1439 }
1440 let mut layers = Vec::with_capacity(layer_sizes.len() - 1);
1441 for i in 0..(layer_sizes.len() - 1) {
1442 layers.push(EoSlimmableLinear::new(
1443 layer_sizes[i],
1444 layer_sizes[i + 1],
1445 seed + i as u64,
1446 ));
1447 }
1448 Ok(Self {
1449 layers,
1450 width_options: vec![0.25, 0.5, 0.75, 1.0],
1451 })
1452 }
1453
1454 pub fn forward_at_width(&self, input: &[f64], width_mult: f64) -> Result<Vec<f64>> {
1456 let mut x = input.to_vec();
1457 for (idx, layer) in self.layers.iter().enumerate() {
1458 x = layer.forward_at_width(&x, width_mult)?;
1459 if idx < self.layers.len() - 1 {
1461 for v in &mut x {
1462 if *v < 0.0 {
1463 *v = 0.0;
1464 }
1465 }
1466 }
1467 }
1468 Ok(x)
1469 }
1470
1471 pub fn inplace_distillation_loss(&self, input: &[f64], width_mult: f64) -> Result<f64> {
1474 let teacher_out = self.forward_at_width(input, 1.0)?;
1475 let student_out = self.forward_at_width(input, width_mult)?;
1476 let n = teacher_out.len().min(student_out.len());
1477 let loss: f64 = (0..n)
1478 .map(|i| (teacher_out[i] - student_out[i]).powi(2))
1479 .sum::<f64>()
1480 / n.max(1) as f64;
1481 Ok(loss)
1482 }
1483}
1484
/// Signed 32-bit fixed-point value with `frac_bits` fractional bits
/// (Q-format).
///
/// NOTE(review): arithmetic assumes both operands share the same
/// `frac_bits`; mixed formats are not checked — confirm callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FixedPoint {
    pub raw: i32,
    pub frac_bits: u8,
}

impl FixedPoint {
    /// Quantizes `value` by scaling with 2^frac_bits, rounding to nearest,
    /// and clamping to the i32 range.
    pub fn from_f64(value: f64, frac_bits: u8) -> Self {
        let scaled = value * (1_i64 << frac_bits) as f64;
        let raw = scaled.round().clamp(i32::MIN as f64, i32::MAX as f64) as i32;
        Self { raw, frac_bits }
    }

    /// Converts back to floating point.
    pub fn to_f64(self) -> f64 {
        let scale = (1_i64 << self.frac_bits) as f64;
        self.raw as f64 / scale
    }

    /// Fixed-point multiply: widens to 64-bit, shifts back down by
    /// `self.frac_bits` (arithmetic shift), and saturates at the i32 bounds.
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, other: Self) -> Self {
        let wide = i64::from(self.raw) * i64::from(other.raw);
        let raw = (wide >> self.frac_bits).clamp(i32::MIN as i64, i32::MAX as i64) as i32;
        Self {
            raw,
            frac_bits: self.frac_bits,
        }
    }

    /// Saturating fixed-point addition.
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, other: Self) -> Self {
        Self {
            raw: self.raw.saturating_add(other.raw),
            frac_bits: self.frac_bits,
        }
    }
}
1532
/// A fully-connected layer quantized to 8-bit integer weights with 32-bit
/// integer bias/accumulators, for integer-only inference.
#[derive(Debug, Clone)]
pub struct IntegerLinear {
    /// Row-major quantized weights, `out_features x in_features`.
    pub weight_i8: Vec<i8>,
    /// Bias quantized at scale `input_scale * weight_scale`, so it can be
    /// added directly to the integer accumulator.
    pub bias_i32: Vec<i32>,
    /// Scale used to quantize float inputs to i8 (fixed at 1/127 — inputs
    /// are assumed to lie roughly in [-1, 1] and saturate outside it).
    pub input_scale: f64,
    /// Per-tensor symmetric weight scale (max |w| / 127).
    pub weight_scale: f64,
    /// Scale mapping i32 accumulators back to float outputs.
    pub output_scale: f64,
    pub in_features: usize,
    pub out_features: usize,
}
1549
1550impl IntegerLinear {
1551 pub fn from_float(
1554 weights: &[f64],
1555 bias: &[f64],
1556 in_features: usize,
1557 out_features: usize,
1558 ) -> Result<Self> {
1559 if weights.len() != out_features * in_features {
1560 return Err(TensorError::compute_error_simple(format!(
1561 "IntegerLinear: weight size {} != {}x{}",
1562 weights.len(),
1563 out_features,
1564 in_features,
1565 )));
1566 }
1567 if bias.len() != out_features {
1568 return Err(TensorError::compute_error_simple(format!(
1569 "IntegerLinear: bias size {} != {}",
1570 bias.len(),
1571 out_features,
1572 )));
1573 }
1574
1575 let w_max = weights.iter().map(|w| w.abs()).fold(0.0_f64, f64::max);
1577 let weight_scale = if w_max > 1e-30 { w_max / 127.0 } else { 1.0 };
1578 let weight_i8: Vec<i8> = weights
1579 .iter()
1580 .map(|&w| (w / weight_scale).round().clamp(-128.0, 127.0) as i8)
1581 .collect();
1582
1583 let input_scale = 1.0 / 127.0;
1585
1586 let bias_scale = input_scale * weight_scale;
1588 let bias_i32: Vec<i32> = bias
1589 .iter()
1590 .map(|&b| {
1591 if bias_scale > 1e-30 {
1592 (b / bias_scale)
1593 .round()
1594 .clamp(i32::MIN as f64, i32::MAX as f64) as i32
1595 } else {
1596 0
1597 }
1598 })
1599 .collect();
1600
1601 let output_scale = input_scale * weight_scale;
1602
1603 Ok(Self {
1604 weight_i8,
1605 bias_i32,
1606 input_scale,
1607 weight_scale,
1608 output_scale,
1609 in_features,
1610 out_features,
1611 })
1612 }
1613
1614 pub fn forward_int(&self, input_i8: &[i8]) -> Result<Vec<i32>> {
1616 if input_i8.len() < self.in_features {
1617 return Err(TensorError::compute_error_simple(format!(
1618 "IntegerLinear forward: input len {} < in_features {}",
1619 input_i8.len(),
1620 self.in_features,
1621 )));
1622 }
1623 let mut output = Vec::with_capacity(self.out_features);
1624 for o in 0..self.out_features {
1625 let mut acc: i32 = self.bias_i32[o];
1626 for i in 0..self.in_features {
1627 acc = acc.saturating_add(
1628 (self.weight_i8[o * self.in_features + i] as i32) * (input_i8[i] as i32),
1629 );
1630 }
1631 output.push(acc);
1632 }
1633 Ok(output)
1634 }
1635
1636 pub fn forward_float(&self, input: &[f64]) -> Result<Vec<f64>> {
1638 let input_i8: Vec<i8> = input
1639 .iter()
1640 .map(|&x| (x / self.input_scale).round().clamp(-128.0, 127.0) as i8)
1641 .collect();
1642 let acc = self.forward_int(&input_i8)?;
1643 Ok(acc.iter().map(|&a| a as f64 * self.output_scale).collect())
1644 }
1645
1646 pub fn quantized_relu(acc: &[i32]) -> Vec<i32> {
1648 acc.iter().map(|&a| a.max(0)).collect()
1649 }
1650
1651 pub fn quantized_sigmoid_approx(acc: &[i32], scale: f64) -> Vec<i32> {
1654 let frac_bits = 8;
1656 let one = 1 << frac_bits; acc.iter()
1658 .map(|&a| {
1659 let x = a as f64 * scale;
1660 let y = if x < -4.0 {
1661 0.0
1662 } else if x > 4.0 {
1663 1.0
1664 } else {
1665 (0.125 * x + 0.5).clamp(0.0, 1.0)
1667 };
1668 (y * one as f64).round() as i32
1669 })
1670 .collect()
1671 }
1672}
1673
/// Coarse layer taxonomy used for activation-memory estimation and fusion
/// detection on edge targets.
#[derive(Debug, Clone, PartialEq)]
pub enum EoLayerType {
    /// Convolution layer; only `out_ch` is used by the memory estimator.
    /// NOTE(review): `kernel` is a single value — presumably a square kernel;
    /// confirm against callers.
    Conv {
        in_ch: usize,
        out_ch: usize,
        kernel: usize,
    },
    /// Fully-connected layer.
    Linear {
        in_feat: usize,
        out_feat: usize,
    },
    /// Per-channel batch normalization.
    BatchNorm {
        channels: usize,
    },
    /// Elementwise ReLU activation (carries no shape info of its own).
    Relu,
    /// Spatial downsampling by `factor` in each dimension.
    Pool {
        factor: usize,
    },
    /// Opaque layer with a caller-supplied activation-memory estimate.
    Custom {
        name: String,
        memory_bytes: usize,
    },
}
1702
/// Description of one concrete layer: its type plus the spatial and batch
/// dimensions used for activation-memory estimation.
#[derive(Debug, Clone)]
pub struct EoLayerDesc {
    pub name: String,
    pub layer_type: EoLayerType,
    /// Spatial (height, width). NOTE(review): assumed to describe the
    /// layer's activation resolution — confirm against callers.
    pub spatial: (usize, usize),
    pub batch_size: usize,
}
1713
/// A detected fusion opportunity spanning a run of adjacent layers.
#[derive(Debug, Clone)]
pub struct EoFusionOp {
    /// Indices (into the layer list) of the layers being fused.
    pub layer_indices: Vec<usize>,
    /// Human-readable description of the fusion.
    pub description: String,
    /// Activation bytes saved by eliminating intermediate outputs.
    pub memory_saved: usize,
}
1724
/// Result of memory planning: per-layer estimates plus the optimizations
/// chosen and whether the plan meets the budget.
#[derive(Debug, Clone)]
pub struct EoAllocationPlan {
    /// Estimated activation bytes per layer (before optimizations).
    pub layer_memory: Vec<usize>,
    /// Indices of layers selected for checkpointing, in ascending order.
    pub checkpoint_layers: Vec<usize>,
    /// Fusions applied to reduce intermediate activations.
    pub fusions: Vec<EoFusionOp>,
    /// Estimated peak activation memory after fusion and checkpointing.
    pub peak_memory_bytes: usize,
    /// True when `peak_memory_bytes` is within the allocator's budget.
    pub fits_budget: bool,
}
1739
/// Plans activation-memory usage (layer fusion plus checkpointing) against a
/// fixed byte budget for memory-constrained edge devices.
#[derive(Debug, Clone)]
pub struct MemoryBudgetAllocator {
    /// Maximum activation memory allowed, in bytes.
    pub budget_bytes: usize,
}
1748
1749impl MemoryBudgetAllocator {
1750 pub fn new(budget_bytes: usize) -> Self {
1751 Self { budget_bytes }
1752 }
1753
1754 pub fn estimate_layer_memory(layer: &EoLayerDesc) -> usize {
1756 let (h, w) = layer.spatial;
1757 let batch = layer.batch_size;
1758 match &layer.layer_type {
1759 EoLayerType::Conv { out_ch, .. } => batch * (*out_ch) * h * w * 4,
1760 EoLayerType::Linear { out_feat, .. } => batch * (*out_feat) * 4,
1761 EoLayerType::BatchNorm { channels } => batch * (*channels) * h * w * 4,
1762 EoLayerType::Relu => batch * h * w * 4, EoLayerType::Pool { factor } => {
1764 let ph = (h + factor - 1) / factor;
1765 let pw = (w + factor - 1) / factor;
1766 batch * ph * pw * 4
1767 }
1768 EoLayerType::Custom { memory_bytes, .. } => *memory_bytes,
1769 }
1770 }
1771
1772 pub fn detect_fusions(layers: &[EoLayerDesc]) -> Vec<EoFusionOp> {
1774 let mut fusions = Vec::new();
1775 let n = layers.len();
1776 let mut i = 0;
1777 while i + 2 < n {
1778 let is_conv = matches!(&layers[i].layer_type, EoLayerType::Conv { .. });
1779 let is_bn = matches!(&layers[i + 1].layer_type, EoLayerType::BatchNorm { .. });
1780 let is_relu = matches!(&layers[i + 2].layer_type, EoLayerType::Relu);
1781 if is_conv && is_bn && is_relu {
1782 let bn_mem = Self::estimate_layer_memory(&layers[i + 1]);
1783 let relu_mem = Self::estimate_layer_memory(&layers[i + 2]);
1784 fusions.push(EoFusionOp {
1785 layer_indices: vec![i, i + 1, i + 2],
1786 description: format!(
1787 "Fuse {}/{}/{} -> single Conv+BN+ReLU",
1788 layers[i].name,
1789 layers[i + 1].name,
1790 layers[i + 2].name
1791 ),
1792 memory_saved: bn_mem + relu_mem,
1793 });
1794 i += 3;
1795 } else {
1796 i += 1;
1797 }
1798 }
1799 fusions
1800 }
1801
1802 pub fn plan_memory(&self, layers: &[EoLayerDesc]) -> EoAllocationPlan {
1804 let layer_memory: Vec<usize> = layers
1805 .iter()
1806 .map(Self::estimate_layer_memory)
1807 .collect();
1808 let fusions = Self::detect_fusions(layers);
1809
1810 let total_no_opt: usize = layer_memory.iter().sum();
1812
1813 let fusion_savings: usize = fusions.iter().map(|f| f.memory_saved).sum();
1815 let total_after_fusion = total_no_opt.saturating_sub(fusion_savings);
1816
1817 let mut checkpoint_layers = Vec::new();
1819 let mut current_mem = total_after_fusion;
1820
1821 if current_mem > self.budget_bytes && layers.len() > 2 {
1822 let mut sorted: Vec<(usize, usize)> = layer_memory
1824 .iter()
1825 .enumerate()
1826 .filter(|&(i, _)| i > 0 && i < layers.len() - 1)
1827 .map(|(i, &m)| (i, m))
1828 .collect();
1829 sorted.sort_by_key(|a| std::cmp::Reverse(a.1));
1830
1831 for (idx, mem) in sorted {
1832 if current_mem <= self.budget_bytes {
1833 break;
1834 }
1835 let savings = mem / 2;
1838 current_mem = current_mem.saturating_sub(savings);
1839 checkpoint_layers.push(idx);
1840 }
1841 checkpoint_layers.sort();
1842 }
1843
1844 let peak_memory_bytes = current_mem;
1845 let fits_budget = peak_memory_bytes <= self.budget_bytes;
1846
1847 EoAllocationPlan {
1848 layer_memory,
1849 checkpoint_layers,
1850 fusions,
1851 peak_memory_bytes,
1852 fits_budget,
1853 }
1854 }
1855}
1856
/// Summary metrics for an edge-deployed (compressed) model.
#[derive(Debug, Clone)]
pub struct EdgeMetrics {
    /// Original / compressed parameter count (0.0 when compressed is zero).
    pub compression_ratio: f64,
    /// Original / compressed FLOPs (0.0 when compressed FLOPs is zero).
    pub speedup_factor: f64,
    /// Runtime activation memory footprint, in bytes.
    pub memory_footprint_bytes: usize,
    /// Serialized model size in bytes (params * 4, assuming f32 storage).
    pub model_size_bytes: usize,
    /// FLOPs of the compressed model.
    pub flops: f64,
    /// Task accuracy of the compressed model.
    pub accuracy: f64,
}

impl EdgeMetrics {
    /// Derives all metrics from raw before/after measurements.
    ///
    /// Ratios degrade gracefully to 0.0 when the compressed denominator is
    /// zero, rather than dividing by zero.
    pub fn compute(
        original_params: usize,
        compressed_params: usize,
        original_flops: f64,
        compressed_flops: f64,
        memory_bytes: usize,
        accuracy: f64,
    ) -> Self {
        let compression_ratio = match compressed_params {
            0 => 0.0,
            n => original_params as f64 / n as f64,
        };
        let speedup_factor = match compressed_flops {
            f if f > 0.0 => original_flops / f,
            _ => 0.0,
        };
        Self {
            compression_ratio,
            speedup_factor,
            memory_footprint_bytes: memory_bytes,
            // Model size assumes 4-byte (f32) parameters.
            model_size_bytes: compressed_params * 4,
            flops: compressed_flops,
            accuracy,
        }
    }

    /// Harmonic mean of accuracy and normalized compression (ratio / 10,
    /// capped at 1.0); returns 0.0 when both terms are zero.
    pub fn efficiency_score(&self) -> f64 {
        let a = self.accuracy.clamp(0.0, 1.0);
        let c = (self.compression_ratio / 10.0).clamp(0.0, 1.0);
        if a + c > 0.0 {
            2.0 * a * c / (a + c)
        } else {
            0.0
        }
    }
}
1919
/// Aggregated deployment report for one model on one target device.
#[derive(Debug, Clone)]
pub struct EdgeReport {
    pub model_name: String,
    pub target_device: String,
    /// Headline compression/speed/accuracy metrics.
    pub metrics: EdgeMetrics,
    /// Architecture candidates for the accuracy/latency Pareto analysis.
    pub pareto_candidates: Vec<EoArchCandidate>,
    /// Memory plan, present only when memory planning was run.
    pub allocation_plan: Option<EoAllocationPlan>,
    /// Approximation errors from tensor decompositions — NOTE(review):
    /// presumably populated from `CpFactors::approx_error` by callers; verify.
    pub decomposition_errors: Vec<f64>,
}
1933
1934impl EdgeReport {
1935 pub fn new(model_name: &str, target_device: &str, metrics: EdgeMetrics) -> Self {
1936 Self {
1937 model_name: model_name.to_string(),
1938 target_device: target_device.to_string(),
1939 metrics,
1940 pareto_candidates: Vec::new(),
1941 allocation_plan: None,
1942 decomposition_errors: Vec::new(),
1943 }
1944 }
1945
1946 pub fn summary(&self) -> String {
1948 let mut s = String::new();
1949 s.push_str(&format!("=== Edge Report: {} ===\n", self.model_name));
1950 s.push_str(&format!("Target: {}\n", self.target_device));
1951 s.push_str(&format!(
1952 "Compression: {:.2}x\n",
1953 self.metrics.compression_ratio
1954 ));
1955 s.push_str(&format!("Speedup: {:.2}x\n", self.metrics.speedup_factor));
1956 s.push_str(&format!(
1957 "Memory: {} bytes\n",
1958 self.metrics.memory_footprint_bytes
1959 ));
1960 s.push_str(&format!(
1961 "Model size: {} bytes\n",
1962 self.metrics.model_size_bytes
1963 ));
1964 s.push_str(&format!("FLOPs: {:.0}\n", self.metrics.flops));
1965 s.push_str(&format!("Accuracy: {:.4}\n", self.metrics.accuracy));
1966 s.push_str(&format!(
1967 "Efficiency score: {:.4}\n",
1968 self.metrics.efficiency_score()
1969 ));
1970 if !self.pareto_candidates.is_empty() {
1971 s.push_str(&format!(
1972 "Pareto candidates: {}\n",
1973 self.pareto_candidates.len()
1974 ));
1975 }
1976 if let Some(ref plan) = self.allocation_plan {
1977 s.push_str(&format!(
1978 "Memory plan: peak={} bytes, fits={}\n",
1979 plan.peak_memory_bytes, plan.fits_budget
1980 ));
1981 }
1982 s
1983 }
1984
1985 pub fn pareto_analysis_summary(&self) -> Vec<(f64, f64)> {
1987 self.pareto_candidates
1988 .iter()
1989 .map(|c| (c.accuracy_proxy, c.estimated_latency_ms))
1990 .collect()
1991 }
1992}
1993
1994#[cfg(test)]
1999mod tests;