Skip to main content

tenflowers_neural/edge_optimization/
mod.rs

1//! Edge & Mobile ML Optimization
2//!
3//! Production-grade tensor decomposition, quantization, hardware-aware NAS,
4//! dynamic-width inference, integer arithmetic, and memory-budget planning
5//! for deploying models on resource-constrained edge / mobile devices.
6//!
7//! # Key Components
8//!
9//! * [`EoCpDecomposition`] --- Canonical Polyadic (CP) decomposition via ALS
10//! * [`EoTuckerDecomposition`] --- Tucker decomposition via HOSVD (power-iteration SVD)
11//! * [`TtDecomposition`] --- Tensor-Train (TT-SVD) decomposition
12//! * [`CodebookQuantization`] --- Weight sharing via k-means codebook (Lloyd)
13//! * [`ProductQuantization`] --- Sub-vector PQ with asymmetric distance computation
14//! * [`HardwareProfile`] / [`HardwareAwareSearch`] --- NAS under latency/memory budgets
15//! * [`DynamicWidthNetwork`] --- Slimmable runtime width adaptation
16//! * [`IntegerLinear`] --- Integer-only (Q-format fixed-point) inference layer
17//! * [`MemoryBudgetAllocator`] --- Activation-checkpointing & operator-fusion planner
18//! * [`EdgeMetrics`] / [`EdgeReport`] --- Compression / efficiency diagnostics
19
20use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
21use scirs2_core::RngExt;
22use tenflowers_core::{Result, TensorError};
23
24mod helpers;
25use helpers::*;
26
27// ============================================================================
28// 1. EoCpDecomposition
29// ============================================================================
30
/// Output of a CP (CANDECOMP/PARAFAC) factorization of a 3-D tensor.
///
/// The tensor is approximated by the weighted sum of rank-1 outer products
///
/// ```text
/// T ~ sum_r lambdas[r] * a_r (x) b_r (x) c_r
/// ```
///
/// where `a_r`, `b_r`, `c_r` are the `r`-th columns of the three factor matrices.
#[derive(Debug, Clone)]
pub struct CpFactors {
    /// Mode-0 factor matrix: `dim_i` rows, one column per component.
    pub factor_a: Vec<Vec<f64>>,
    /// Mode-1 factor matrix: `dim_j` rows, one column per component.
    pub factor_b: Vec<Vec<f64>>,
    /// Mode-2 factor matrix: `dim_k` rows, one column per component.
    pub factor_c: Vec<Vec<f64>>,
    /// Scale of each rank-1 component (one entry per component).
    pub lambdas: Vec<f64>,
    /// Relative Frobenius-norm reconstruction error `||T - T_hat|| / ||T||`.
    pub approx_error: f64,
    /// Number of ALS sweeps that were executed.
    pub iterations: usize,
}
47
/// Alternating-least-squares (ALS) solver for the Canonical Polyadic
/// (CP / CANDECOMP-PARAFAC) decomposition.
///
/// A 3-D tensor `T` of shape `(I, J, K)` is approximated by `rank` rank-1 terms:
///
/// ```text
/// T ~ sum_{r=1}^{R} lambda_r * a_r (x) b_r (x) c_r
/// ```
///
/// where `(x)` is the outer product.
#[derive(Debug, Clone)]
pub struct EoCpDecomposition {
    /// Number of rank-1 components `R`.
    pub rank: usize,
    /// Maximum number of ALS sweeps.
    pub max_iters: usize,
    /// Early-stop threshold on the change in relative error between sweeps.
    pub tol: f64,
    /// Seed for the random factor initialisation.
    pub seed: u64,
}
61
62impl EoCpDecomposition {
63    /// Create a new CP decomposition solver.
64    pub fn new(rank: usize, max_iters: usize) -> Self {
65        Self {
66            rank,
67            max_iters,
68            tol: 1e-8,
69            seed: 42,
70        }
71    }
72
73    /// Set convergence tolerance.
74    pub fn with_tol(mut self, tol: f64) -> Self {
75        self.tol = tol;
76        self
77    }
78
79    /// Decompose a 3-D tensor stored in row-major order.
80    /// `tensor` has shape (dim_i, dim_j, dim_k).
81    pub fn decompose(
82        &self,
83        tensor: &[f64],
84        dim_i: usize,
85        dim_j: usize,
86        dim_k: usize,
87    ) -> Result<CpFactors> {
88        let total = dim_i * dim_j * dim_k;
89        if tensor.len() != total {
90            return Err(TensorError::compute_error_simple(format!(
91                "CP decomposition: tensor length {} != {}*{}*{} = {}",
92                tensor.len(),
93                dim_i,
94                dim_j,
95                dim_k,
96                total,
97            )));
98        }
99
100        let r = self.rank;
101        let mut rng = StdRng::seed_from_u64(self.seed);
102
103        // Initialize factor matrices randomly
104        let mut a: Vec<Vec<f64>> = (0..dim_i)
105            .map(|_| (0..r).map(|_| rng.random_range(-1.0..1.0)).collect())
106            .collect();
107        let mut b: Vec<Vec<f64>> = (0..dim_j)
108            .map(|_| (0..r).map(|_| rng.random_range(-1.0..1.0)).collect())
109            .collect();
110        let mut c: Vec<Vec<f64>> = (0..dim_k)
111            .map(|_| (0..r).map(|_| rng.random_range(-1.0..1.0)).collect())
112            .collect();
113
114        let tensor_norm = frobenius(tensor);
115        let mut prev_error = f64::MAX;
116        let mut iterations = 0;
117
118        for iter in 0..self.max_iters {
119            iterations = iter + 1;
120
121            // Mode-0 unfolding: X_(0) is (I x JK), factor update: A = X_(0) * (C kr B) * pinv(...)
122            // A <- X_(0) (C kr B) [(B^T B * C^T C)]^{-1}
123            let kr_cb = khatri_rao(&c, &b);
124            let unfold_0 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 0);
125            a = self.update_factor(&unfold_0, &kr_cb, r)?;
126
127            // B <- X_(1) (C kr A) [(A^T A * C^T C)]^{-1}
128            let kr_ca = khatri_rao(&c, &a);
129            let unfold_1 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 1);
130            b = self.update_factor(&unfold_1, &kr_ca, r)?;
131
132            // C <- X_(2) (B kr A) [(A^T A * B^T B)]^{-1}
133            let kr_ba = khatri_rao(&b, &a);
134            let unfold_2 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 2);
135            c = self.update_factor(&unfold_2, &kr_ba, r)?;
136
137            // Check convergence via reconstruction error
138            let recon = self.reconstruct_flat(&a, &b, &c, &vec![1.0; r], dim_i, dim_j, dim_k);
139            let err_vec: Vec<f64> = tensor
140                .iter()
141                .zip(recon.iter())
142                .map(|(t, r)| t - r)
143                .collect();
144            let error = frobenius(&err_vec) / (tensor_norm + 1e-30);
145
146            if (prev_error - error).abs() < self.tol {
147                break;
148            }
149            prev_error = error;
150        }
151
152        // Normalize columns and extract lambdas
153        let mut lambdas = vec![1.0_f64; r];
154        for col in 0..r {
155            let na: f64 = a.iter().map(|row| row[col] * row[col]).sum::<f64>().sqrt();
156            let nb: f64 = b.iter().map(|row| row[col] * row[col]).sum::<f64>().sqrt();
157            let nc: f64 = c.iter().map(|row| row[col] * row[col]).sum::<f64>().sqrt();
158            lambdas[col] = na * nb * nc;
159            if na > 1e-15 {
160                for row in &mut a {
161                    row[col] /= na;
162                }
163            }
164            if nb > 1e-15 {
165                for row in &mut b {
166                    row[col] /= nb;
167                }
168            }
169            if nc > 1e-15 {
170                for row in &mut c {
171                    row[col] /= nc;
172                }
173            }
174        }
175
176        let recon = self.reconstruct_flat(&a, &b, &c, &lambdas, dim_i, dim_j, dim_k);
177        let err_vec: Vec<f64> = tensor
178            .iter()
179            .zip(recon.iter())
180            .map(|(t, r)| t - r)
181            .collect();
182        let approx_error = frobenius(&err_vec) / (tensor_norm + 1e-30);
183
184        Ok(CpFactors {
185            factor_a: a,
186            factor_b: b,
187            factor_c: c,
188            lambdas,
189            approx_error,
190            iterations,
191        })
192    }
193
194    /// Reconstruct the full tensor from CP factors.
195    pub fn reconstruct(factors: &CpFactors, dim_i: usize, dim_j: usize, dim_k: usize) -> Vec<f64> {
196        let r = factors.lambdas.len();
197        let mut out = vec![0.0_f64; dim_i * dim_j * dim_k];
198        for comp in 0..r {
199            let lam = factors.lambdas[comp];
200            for i in 0..dim_i {
201                for j in 0..dim_j {
202                    for k in 0..dim_k {
203                        out[i * dim_j * dim_k + j * dim_k + k] += lam
204                            * factors.factor_a[i][comp]
205                            * factors.factor_b[j][comp]
206                            * factors.factor_c[k][comp];
207                    }
208                }
209            }
210        }
211        out
212    }
213
214    // -- helpers --
215
216    fn mode_unfold(
217        &self,
218        tensor: &[f64],
219        di: usize,
220        dj: usize,
221        dk: usize,
222        mode: usize,
223    ) -> Vec<Vec<f64>> {
224        match mode {
225            0 => {
226                // (I x JK)
227                (0..di)
228                    .map(|i| {
229                        let mut row = Vec::with_capacity(dj * dk);
230                        for j in 0..dj {
231                            for k in 0..dk {
232                                row.push(tensor[i * dj * dk + j * dk + k]);
233                            }
234                        }
235                        row
236                    })
237                    .collect()
238            }
239            1 => {
240                // (J x IK)
241                (0..dj)
242                    .map(|j| {
243                        let mut row = Vec::with_capacity(di * dk);
244                        for i in 0..di {
245                            for k in 0..dk {
246                                row.push(tensor[i * dj * dk + j * dk + k]);
247                            }
248                        }
249                        row
250                    })
251                    .collect()
252            }
253            _ => {
254                // mode 2: (K x IJ)
255                (0..dk)
256                    .map(|k| {
257                        let mut row = Vec::with_capacity(di * dj);
258                        for i in 0..di {
259                            for j in 0..dj {
260                                row.push(tensor[i * dj * dk + j * dk + k]);
261                            }
262                        }
263                        row
264                    })
265                    .collect()
266            }
267        }
268    }
269
270    fn update_factor(
271        &self,
272        unfold: &[Vec<f64>],
273        kr: &[Vec<f64>],
274        _rank: usize,
275    ) -> Result<Vec<Vec<f64>>> {
276        // factor = unfold * kr * pinv(kr^T kr)
277        let product = mat_mul(unfold, kr);
278        let krtk = mat_mul(&mat_t(kr), kr);
279        let n = krtk.len();
280        // Invert krtk (small RxR) via Gauss-Jordan
281        let mut aug: Vec<Vec<f64>> = krtk
282            .iter()
283            .enumerate()
284            .map(|(i, row)| {
285                let mut r = row.clone();
286                for j in 0..n {
287                    r.push(if i == j { 1.0 } else { 0.0 });
288                }
289                r
290            })
291            .collect();
292        for col in 0..n {
293            let mut max_row = col;
294            let mut max_val = aug[col][col].abs();
295            for row in (col + 1)..n {
296                let v = aug[row][col].abs();
297                if v > max_val {
298                    max_val = v;
299                    max_row = row;
300                }
301            }
302            if max_val < 1e-14 {
303                // Add small regularization instead of failing
304                aug[col][col] += 1e-10;
305            }
306            aug.swap(col, max_row);
307            let pivot = aug[col][col];
308            for j in 0..(2 * n) {
309                aug[col][j] /= pivot;
310            }
311            for row in 0..n {
312                if row == col {
313                    continue;
314                }
315                let factor = aug[row][col];
316                for j in 0..(2 * n) {
317                    aug[row][j] -= factor * aug[col][j];
318                }
319            }
320        }
321        let inv: Vec<Vec<f64>> = aug.iter().map(|r| r[n..].to_vec()).collect();
322        Ok(mat_mul(&product, &inv))
323    }
324
325    fn reconstruct_flat(
326        &self,
327        a: &[Vec<f64>],
328        b: &[Vec<f64>],
329        c: &[Vec<f64>],
330        lambdas: &[f64],
331        di: usize,
332        dj: usize,
333        dk: usize,
334    ) -> Vec<f64> {
335        let r = lambdas.len();
336        let mut out = vec![0.0_f64; di * dj * dk];
337        for comp in 0..r {
338            let lam = lambdas[comp];
339            for i in 0..di {
340                for j in 0..dj {
341                    for k in 0..dk {
342                        out[i * dj * dk + j * dk + k] += lam * a[i][comp] * b[j][comp] * c[k][comp];
343                    }
344                }
345            }
346        }
347        out
348    }
349}
350
351// ============================================================================
352// 2. EoTuckerDecomposition
353// ============================================================================
354
/// Output of a Tucker decomposition: a small core tensor plus one factor matrix per mode.
///
/// The decomposed tensor is approximated as `G x_1 U1 x_2 U2 x_3 U3`.
#[derive(Debug, Clone)]
pub struct TuckerFactors {
    /// Core tensor `G`, flattened in row-major order with shape `core_shape`.
    pub core: Vec<f64>,
    /// Shape `(r1, r2, r3)` of the core tensor.
    pub core_shape: (usize, usize, usize),
    /// Factor matrices `[U1, U2, U3]`; `U_n` has `dim_n` rows and `r_n` columns.
    pub factors: Vec<Vec<Vec<f64>>>,
    /// Dense element count divided by stored (core + factor) element count.
    pub compression_ratio: f64,
    /// Relative Frobenius-norm reconstruction error.
    pub approx_error: f64,
}
368
/// Tucker decomposition computed by truncated HOSVD (higher-order SVD).
///
/// ```text
/// T ~ G x_1 U1 x_2 U2 x_3 U3
/// ```
///
/// where `G` is a small core tensor of shape `(r1, r2, r3)`.
#[derive(Debug, Clone)]
pub struct EoTuckerDecomposition {
    /// Requested rank `(r1, r2, r3)` for each of the three modes.
    pub ranks: (usize, usize, usize),
    /// Iteration budget passed to the truncated SVD routine.
    pub svd_iters: usize,
    /// Base seed passed to the truncated SVD routine (offset per mode).
    pub seed: u64,
}
380
381impl EoTuckerDecomposition {
382    pub fn new(ranks: (usize, usize, usize)) -> Self {
383        Self {
384            ranks,
385            svd_iters: 50,
386            seed: 42,
387        }
388    }
389
390    /// Decompose a 3-D tensor T(dim_i, dim_j, dim_k) via HOSVD.
391    pub fn decompose(
392        &self,
393        tensor: &[f64],
394        dim_i: usize,
395        dim_j: usize,
396        dim_k: usize,
397    ) -> Result<TuckerFactors> {
398        let total = dim_i * dim_j * dim_k;
399        if tensor.len() != total {
400            return Err(TensorError::compute_error_simple(format!(
401                "Tucker decomposition: tensor length {} != {}",
402                tensor.len(),
403                total,
404            )));
405        }
406
407        let (r1, r2, r3) = self.ranks;
408        let tensor_norm = frobenius(tensor);
409
410        // Mode-0 unfolding -> truncated SVD -> U1
411        let unfold_0 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 0);
412        let (u1, _s1, _vt1) = truncated_svd(&unfold_0, r1, self.svd_iters, self.seed)?;
413
414        // Mode-1 unfolding -> truncated SVD -> U2
415        let unfold_1 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 1);
416        let (u2, _s2, _vt2) = truncated_svd(&unfold_1, r2, self.svd_iters, self.seed + 1)?;
417
418        // Mode-2 unfolding -> truncated SVD -> U3
419        let unfold_2 = self.mode_unfold(tensor, dim_i, dim_j, dim_k, 2);
420        let (u3, _s3, _vt3) = truncated_svd(&unfold_2, r3, self.svd_iters, self.seed + 2)?;
421
422        // Core tensor: G = T x_1 U1^T x_2 U2^T x_3 U3^T
423        let core = self.compute_core(tensor, dim_i, dim_j, dim_k, &u1, &u2, &u3);
424        let actual_r1 = u1.first().map_or(0, |r| r.len());
425        let actual_r2 = u2.first().map_or(0, |r| r.len());
426        let actual_r3 = u3.first().map_or(0, |r| r.len());
427
428        // Compression ratio
429        let original_size = total;
430        let compressed_size = actual_r1 * actual_r2 * actual_r3
431            + dim_i * actual_r1
432            + dim_j * actual_r2
433            + dim_k * actual_r3;
434        let compression_ratio = if compressed_size > 0 {
435            original_size as f64 / compressed_size as f64
436        } else {
437            0.0
438        };
439
440        // Reconstruction error
441        let recon = Self::reconstruct_from_parts(
442            &core, actual_r1, actual_r2, actual_r3, &u1, &u2, &u3, dim_i, dim_j, dim_k,
443        );
444        let err_vec: Vec<f64> = tensor
445            .iter()
446            .zip(recon.iter())
447            .map(|(t, r)| t - r)
448            .collect();
449        let approx_error = frobenius(&err_vec) / (tensor_norm + 1e-30);
450
451        Ok(TuckerFactors {
452            core,
453            core_shape: (actual_r1, actual_r2, actual_r3),
454            factors: vec![u1, u2, u3],
455            compression_ratio,
456            approx_error,
457        })
458    }
459
460    /// Reconstruct full tensor from Tucker factors.
461    pub fn reconstruct(
462        factors: &TuckerFactors,
463        dim_i: usize,
464        dim_j: usize,
465        dim_k: usize,
466    ) -> Vec<f64> {
467        let (r1, r2, r3) = factors.core_shape;
468        Self::reconstruct_from_parts(
469            &factors.core,
470            r1,
471            r2,
472            r3,
473            &factors.factors[0],
474            &factors.factors[1],
475            &factors.factors[2],
476            dim_i,
477            dim_j,
478            dim_k,
479        )
480    }
481
482    // -- helpers --
483
484    fn mode_unfold(
485        &self,
486        tensor: &[f64],
487        di: usize,
488        dj: usize,
489        dk: usize,
490        mode: usize,
491    ) -> Vec<Vec<f64>> {
492        match mode {
493            0 => (0..di)
494                .map(|i| {
495                    let mut row = Vec::with_capacity(dj * dk);
496                    for j in 0..dj {
497                        for k in 0..dk {
498                            row.push(tensor[i * dj * dk + j * dk + k]);
499                        }
500                    }
501                    row
502                })
503                .collect(),
504            1 => (0..dj)
505                .map(|j| {
506                    let mut row = Vec::with_capacity(di * dk);
507                    for i in 0..di {
508                        for k in 0..dk {
509                            row.push(tensor[i * dj * dk + j * dk + k]);
510                        }
511                    }
512                    row
513                })
514                .collect(),
515            _ => (0..dk)
516                .map(|k| {
517                    let mut row = Vec::with_capacity(di * dj);
518                    for i in 0..di {
519                        for j in 0..dj {
520                            row.push(tensor[i * dj * dk + j * dk + k]);
521                        }
522                    }
523                    row
524                })
525                .collect(),
526        }
527    }
528
529    fn compute_core(
530        &self,
531        tensor: &[f64],
532        di: usize,
533        dj: usize,
534        dk: usize,
535        u1: &[Vec<f64>],
536        u2: &[Vec<f64>],
537        u3: &[Vec<f64>],
538    ) -> Vec<f64> {
539        let r1 = u1.first().map_or(0, |r| r.len());
540        let r2 = u2.first().map_or(0, |r| r.len());
541        let r3 = u3.first().map_or(0, |r| r.len());
542        let mut core = vec![0.0_f64; r1 * r2 * r3];
543        // G(a,b,c) = sum_{i,j,k} T(i,j,k) * U1(i,a) * U2(j,b) * U3(k,c)
544        for i in 0..di {
545            for j in 0..dj {
546                let tij_base = i * dj * dk + j * dk;
547                for k in 0..dk {
548                    let val = tensor[tij_base + k];
549                    if val.abs() < 1e-30 {
550                        continue;
551                    }
552                    for a in 0..r1 {
553                        let u1_ia = u1[i][a];
554                        if u1_ia.abs() < 1e-30 {
555                            continue;
556                        }
557                        for b in 0..r2 {
558                            let u2_jb = u2[j][b];
559                            if u2_jb.abs() < 1e-30 {
560                                continue;
561                            }
562                            for c in 0..r3 {
563                                core[a * r2 * r3 + b * r3 + c] += val * u1_ia * u2_jb * u3[k][c];
564                            }
565                        }
566                    }
567                }
568            }
569        }
570        core
571    }
572
573    fn reconstruct_from_parts(
574        core: &[f64],
575        r1: usize,
576        r2: usize,
577        r3: usize,
578        u1: &[Vec<f64>],
579        u2: &[Vec<f64>],
580        u3: &[Vec<f64>],
581        di: usize,
582        dj: usize,
583        dk: usize,
584    ) -> Vec<f64> {
585        let mut out = vec![0.0_f64; di * dj * dk];
586        // T(i,j,k) = sum_{a,b,c} G(a,b,c) * U1(i,a) * U2(j,b) * U3(k,c)
587        for a in 0..r1 {
588            for b in 0..r2 {
589                for c in 0..r3 {
590                    let g_abc = core[a * r2 * r3 + b * r3 + c];
591                    if g_abc.abs() < 1e-30 {
592                        continue;
593                    }
594                    for i in 0..di {
595                        let ga_u1 = g_abc * u1[i][a];
596                        if ga_u1.abs() < 1e-30 {
597                            continue;
598                        }
599                        for j in 0..dj {
600                            let gau2 = ga_u1 * u2[j][b];
601                            if gau2.abs() < 1e-30 {
602                                continue;
603                            }
604                            for k in 0..dk {
605                                out[i * dj * dk + j * dk + k] += gau2 * u3[k][c];
606                            }
607                        }
608                    }
609                }
610            }
611        }
612        out
613    }
614}
615
616// ============================================================================
617// 3. TtDecomposition  (Tensor-Train)
618// ============================================================================
619
/// One Tensor-Train core of shape `(r_left, n, r_right)`, stored flat in row-major order.
#[derive(Debug, Clone)]
pub struct TtCore {
    /// Row-major values; element `(a, i, b)` lives at `a * n * r_right + i * r_right + b`.
    pub data: Vec<f64>,
    /// `(r_left, n, r_right)`: left bond rank, mode size, right bond rank.
    pub shape: (usize, usize, usize),
}
626
627/// Result of Tensor-Train decomposition.
628#[derive(Debug, Clone)]
629pub struct TtFactors {
630    pub cores: Vec<TtCore>,
631    /// Original tensor shape.
632    pub original_shape: Vec<usize>,
633    /// Relative Frobenius approximation error.
634    pub approx_error: f64,
635}
636
/// Tensor-Train (TT) decomposition computed by TT-SVD.
///
/// ```text
/// T(i1, ..., id) = G1(:, i1, :) * G2(:, i2, :) * ... * Gd(:, id, :)
/// ```
#[derive(Debug, Clone)]
pub struct TtDecomposition {
    /// Upper bound on every TT bond rank.
    pub max_rank: usize,
    /// Truncation tolerance.
    /// NOTE(review): not read by `decompose` as written; confirm whether it should drive
    /// rank truncation.
    pub tol: f64,
    /// Iteration budget passed to the truncated SVD routine.
    pub svd_iters: usize,
    /// Base seed passed to the truncated SVD routine (offset by core index).
    pub seed: u64,
}
647
648impl TtDecomposition {
649    pub fn new(max_rank: usize) -> Self {
650        Self {
651            max_rank,
652            tol: 1e-8,
653            svd_iters: 50,
654            seed: 42,
655        }
656    }
657
658    /// Decompose a tensor with given shape into TT format.
659    /// `tensor` is stored flat in row-major order for the given shape.
660    pub fn decompose(&self, tensor: &[f64], shape: &[usize]) -> Result<TtFactors> {
661        let total: usize = shape.iter().product();
662        if tensor.len() != total {
663            return Err(TensorError::compute_error_simple(format!(
664                "TT decomposition: tensor length {} != product of shape {:?} = {}",
665                tensor.len(),
666                shape,
667                total,
668            )));
669        }
670        if shape.len() < 2 {
671            return Err(TensorError::compute_error_simple(
672                "TT decomposition requires at least 2 dimensions".to_string(),
673            ));
674        }
675
676        let tensor_norm = frobenius(tensor);
677        let d = shape.len();
678        let mut cores: Vec<TtCore> = Vec::with_capacity(d);
679        let mut c = tensor.to_vec();
680        let mut r_prev = 1_usize;
681
682        // Remaining size of the "tail" dimensions
683        let mut remaining: usize = total;
684
685        for k in 0..(d - 1) {
686            let n_k = shape[k];
687            remaining /= n_k;
688            // Reshape c as (r_prev * n_k) x (remaining)
689            let rows = r_prev * n_k;
690            let cols = remaining;
691            let mat: Vec<Vec<f64>> = (0..rows)
692                .map(|i| (0..cols).map(|j| c[i * cols + j]).collect())
693                .collect();
694
695            let rank = self.max_rank.min(rows).min(cols);
696            let (u, s, vt) = truncated_svd(&mat, rank, self.svd_iters, self.seed + k as u64)?;
697
698            let actual_rank = s.len();
699            // Core_k: (r_prev, n_k, actual_rank)
700            let mut core_data = vec![0.0_f64; r_prev * n_k * actual_rank];
701            for i in 0..rows {
702                for r in 0..actual_rank {
703                    core_data[i * actual_rank + r] =
704                        u.get(i).and_then(|row| row.get(r).copied()).unwrap_or(0.0);
705                }
706            }
707            cores.push(TtCore {
708                data: core_data,
709                shape: (r_prev, n_k, actual_rank),
710            });
711
712            // c <- diag(s) * Vt  (actual_rank x remaining)
713            let mut new_c = vec![0.0_f64; actual_rank * cols];
714            for r in 0..actual_rank {
715                for j in 0..cols {
716                    new_c[r * cols + j] =
717                        s[r] * vt.get(r).and_then(|row| row.get(j).copied()).unwrap_or(0.0);
718                }
719            }
720            c = new_c;
721            r_prev = actual_rank;
722        }
723
724        // Last core: (r_prev, n_{d-1}, 1)
725        let n_last = shape[d - 1];
726        let mut last_core_data = vec![0.0_f64; r_prev * n_last];
727        let copy_len = c.len().min(r_prev * n_last);
728        last_core_data[..copy_len].copy_from_slice(&c[..copy_len]);
729        cores.push(TtCore {
730            data: last_core_data,
731            shape: (r_prev, n_last, 1),
732        });
733
734        // Compute reconstruction error
735        let recon = Self::reconstruct_flat(&cores, shape);
736        let err_vec: Vec<f64> = tensor
737            .iter()
738            .zip(recon.iter())
739            .map(|(t, r)| t - r)
740            .collect();
741        let approx_error = frobenius(&err_vec) / (tensor_norm + 1e-30);
742
743        Ok(TtFactors {
744            cores,
745            original_shape: shape.to_vec(),
746            approx_error,
747        })
748    }
749
750    /// Reconstruct a full tensor from TT cores.
751    pub fn reconstruct(factors: &TtFactors) -> Vec<f64> {
752        Self::reconstruct_flat(&factors.cores, &factors.original_shape)
753    }
754
755    /// TT-rounding: reduce ranks by re-decomposing each core.
756    pub fn round(&self, factors: &TtFactors, new_max_rank: usize) -> Result<TtFactors> {
757        let full = Self::reconstruct_flat(&factors.cores, &factors.original_shape);
758        let mut dec = self.clone();
759        dec.max_rank = new_max_rank;
760        dec.decompose(&full, &factors.original_shape)
761    }
762
763    fn reconstruct_flat(cores: &[TtCore], shape: &[usize]) -> Vec<f64> {
764        let total: usize = shape.iter().product();
765        let d = shape.len();
766        let mut result = vec![0.0_f64; total];
767
768        // For each multi-index, multiply chain of core slices
769        let mut indices = vec![0_usize; d];
770        for flat_idx in 0..total {
771            // Compute multi-index from flat_idx
772            let mut rem = flat_idx;
773            for k in (0..d).rev() {
774                indices[k] = rem % shape[k];
775                rem /= shape[k];
776            }
777
778            // Product of slices: G1(:,i1,:) * G2(:,i2,:) * ... * Gd(:,id,:)
779            // Start with [1] (1x1 matrix)
780            let mut vec_cur: Vec<f64> = vec![1.0];
781            let mut cur_cols = 1_usize;
782
783            for (k, core) in cores.iter().enumerate() {
784                let (r_left, _n_k, r_right) = core.shape;
785                let ik = indices[k];
786                // Extract slice core[:, ik, :] which is (r_left x r_right)
787                // core.data layout: (r_left * n_k * r_right) row-major for (r_left, n_k, r_right)
788                // Element (a, ik, b) = data[a * n_k * r_right + ik * r_right + b]
789                let mut new_vec = vec![0.0_f64; r_right];
790                // vec_cur has cur_cols entries, should match r_left
791                debug_assert_eq!(cur_cols, r_left);
792                for a in 0..r_left {
793                    let v_a = vec_cur[a];
794                    if v_a.abs() < 1e-30 {
795                        continue;
796                    }
797                    for b in 0..r_right {
798                        let core_val = core.data[a * cores[k].shape.1 * r_right + ik * r_right + b];
799                        new_vec[b] += v_a * core_val;
800                    }
801                }
802                vec_cur = new_vec;
803                cur_cols = r_right;
804            }
805
806            result[flat_idx] = vec_cur[0];
807        }
808        result
809    }
810}
811
812// ============================================================================
813// 4. CodebookQuantization
814// ============================================================================
815
/// Output of codebook quantization: a shared codebook plus one codebook index per weight.
#[derive(Debug, Clone)]
pub struct CodebookResult {
    /// Cluster centroids; its length is the number of clusters.
    pub codebook: Vec<f64>,
    /// For each input weight, the position of its centroid in `codebook`.
    pub indices: Vec<usize>,
    /// Achieved compression ratio.
    pub compression_ratio: f64,
    /// Mean squared error between the weights and their quantized values.
    pub mse: f64,
}
828
/// Weight sharing through a k-means codebook fitted with Lloyd's algorithm.
///
/// Every weight is replaced by the index of its nearest centroid, giving a
/// compression of `original_bits / (codebook_bits + index_bits)`.
#[derive(Debug, Clone)]
pub struct CodebookQuantization {
    /// Requested number of centroids (capped at the number of weights when quantizing).
    pub n_clusters: usize,
    /// Maximum number of Lloyd iterations.
    pub max_iters: usize,
    /// Seed for centroid initialisation.
    pub seed: u64,
}
839
840impl CodebookQuantization {
841    pub fn new(n_clusters: usize) -> Self {
842        Self {
843            n_clusters,
844            max_iters: 100,
845            seed: 42,
846        }
847    }
848
849    /// Quantize a weight vector into a codebook + indices.
850    pub fn quantize(&self, weights: &[f64]) -> Result<CodebookResult> {
851        if weights.is_empty() {
852            return Err(TensorError::compute_error_simple(
853                "CodebookQuantization: empty weight vector".to_string(),
854            ));
855        }
856        let k = self.n_clusters.min(weights.len());
857        if k == 0 {
858            return Err(TensorError::compute_error_simple(
859                "CodebookQuantization: n_clusters must be > 0".to_string(),
860            ));
861        }
862
863        let mut rng = StdRng::seed_from_u64(self.seed);
864
865        // Initialize centroids via k-means++ style
866        let mut centroids: Vec<f64> = Vec::with_capacity(k);
867        // Pick first centroid randomly
868        let first_idx = rng.random_range(0..weights.len());
869        centroids.push(weights[first_idx]);
870
871        for _ in 1..k {
872            let mut dist_sq: Vec<f64> = weights
873                .iter()
874                .map(|w| {
875                    centroids
876                        .iter()
877                        .map(|c| (w - c) * (w - c))
878                        .fold(f64::MAX, f64::min)
879                })
880                .collect();
881            let total: f64 = dist_sq.iter().sum();
882            if total < 1e-30 {
883                // All remaining weights are already close to existing centroids
884                centroids.push(weights[rng.random_range(0..weights.len())]);
885                continue;
886            }
887            // Normalize to probabilities
888            for d in &mut dist_sq {
889                *d /= total;
890            }
891            let r: f64 = rng.random_range(0.0..1.0);
892            let mut cumsum = 0.0;
893            let mut chosen = 0;
894            for (i, &d) in dist_sq.iter().enumerate() {
895                cumsum += d;
896                if cumsum >= r {
897                    chosen = i;
898                    break;
899                }
900            }
901            centroids.push(weights[chosen]);
902        }
903
904        // Lloyd's algorithm
905        let mut indices = vec![0_usize; weights.len()];
906        for _iter in 0..self.max_iters {
907            // Assignment step
908            let mut changed = false;
909            for (i, w) in weights.iter().enumerate() {
910                let mut best = 0;
911                let mut best_dist = f64::MAX;
912                for (c, centroid) in centroids.iter().enumerate() {
913                    let d = (w - centroid) * (w - centroid);
914                    if d < best_dist {
915                        best_dist = d;
916                        best = c;
917                    }
918                }
919                if indices[i] != best {
920                    changed = true;
921                    indices[i] = best;
922                }
923            }
924            if !changed {
925                break;
926            }
927
928            // Update step
929            let mut sums = vec![0.0_f64; k];
930            let mut counts = vec![0_usize; k];
931            for (i, w) in weights.iter().enumerate() {
932                sums[indices[i]] += w;
933                counts[indices[i]] += 1;
934            }
935            for c in 0..k {
936                if counts[c] > 0 {
937                    centroids[c] = sums[c] / counts[c] as f64;
938                }
939            }
940        }
941
942        // Compute MSE
943        let mse: f64 = weights
944            .iter()
945            .zip(indices.iter())
946            .map(|(w, &idx)| {
947                let d = w - centroids[idx];
948                d * d
949            })
950            .sum::<f64>()
951            / weights.len() as f64;
952
953        // Compression ratio: original 64 bits per weight  ->  codebook(k*64) + indices(n * log2(k))
954        let index_bits = (k as f64).log2().ceil().max(1.0);
955        let original_bits = weights.len() as f64 * 64.0;
956        let compressed_bits = k as f64 * 64.0 + weights.len() as f64 * index_bits;
957        let compression_ratio = if compressed_bits > 0.0 {
958            original_bits / compressed_bits
959        } else {
960            0.0
961        };
962
963        Ok(CodebookResult {
964            codebook: centroids,
965            indices,
966            compression_ratio,
967            mse,
968        })
969    }
970
971    /// Reconstruct weights from codebook + indices.
972    pub fn dequantize(codebook: &[f64], indices: &[usize]) -> Vec<f64> {
973        indices
974            .iter()
975            .map(|&idx| {
976                if idx < codebook.len() {
977                    codebook[idx]
978                } else {
979                    0.0
980                }
981            })
982            .collect()
983    }
984}
985
986// ============================================================================
987// 5. ProductQuantization
988// ============================================================================
989
/// Product-quantization codes for a batch of vectors.
///
/// Each encoded vector is stored as `n_subquantizers` centroid indices (one per
/// sub-space), together with the per-sub-space codebooks needed to decode them.
#[derive(Debug, Clone)]
pub struct PqCodes {
    /// `codes[i][m]` is the centroid index chosen for sub-vector `m` of vector `i`.
    pub codes: Vec<Vec<usize>>,
    /// `codebooks[m]` is the (n_centroids x sub_dim) centroid table of sub-space `m`.
    pub codebooks: Vec<Vec<Vec<f64>>>,
    /// Number of sub-spaces (M).
    pub n_subquantizers: usize,
    /// Length of every sub-vector.
    pub sub_dim: usize,
}
1000
1001/// Sub-vector Product Quantization (PQ).
1002///
1003/// Splits each vector into M sub-vectors and independently quantizes each
1004/// with k-means. Enables fast asymmetric distance computation (ADC).
1005#[derive(Debug, Clone)]
1006pub struct ProductQuantization {
1007    pub n_subquantizers: usize,
1008    pub n_centroids: usize,
1009    pub max_iters: usize,
1010    pub seed: u64,
1011}
1012
1013impl ProductQuantization {
1014    pub fn new(n_subquantizers: usize, n_centroids: usize) -> Self {
1015        Self {
1016            n_subquantizers,
1017            n_centroids,
1018            max_iters: 50,
1019            seed: 42,
1020        }
1021    }
1022
1023    /// Encode a batch of vectors.  Each vector has dimension `dim`.
1024    /// `vectors` is a flat array: vectors[i * dim + j].
1025    pub fn encode(&self, vectors: &[f64], n_vectors: usize, dim: usize) -> Result<PqCodes> {
1026        if vectors.len() != n_vectors * dim {
1027            return Err(TensorError::compute_error_simple(format!(
1028                "PQ encode: expected {} elements, got {}",
1029                n_vectors * dim,
1030                vectors.len(),
1031            )));
1032        }
1033        if dim % self.n_subquantizers != 0 {
1034            return Err(TensorError::compute_error_simple(format!(
1035                "PQ encode: dim {} not divisible by n_subquantizers {}",
1036                dim, self.n_subquantizers,
1037            )));
1038        }
1039        let sub_dim = dim / self.n_subquantizers;
1040        let m_count = self.n_subquantizers;
1041        let k = self.n_centroids.min(n_vectors);
1042
1043        let mut codebooks: Vec<Vec<Vec<f64>>> = Vec::with_capacity(m_count);
1044        let mut codes: Vec<Vec<usize>> = vec![vec![0_usize; m_count]; n_vectors];
1045        let mut rng = StdRng::seed_from_u64(self.seed);
1046
1047        for m in 0..m_count {
1048            let offset = m * sub_dim;
1049            // Extract sub-vectors for this partition
1050            let sub_vecs: Vec<Vec<f64>> = (0..n_vectors)
1051                .map(|i| {
1052                    (0..sub_dim)
1053                        .map(|d| vectors[i * dim + offset + d])
1054                        .collect()
1055                })
1056                .collect();
1057
1058            // k-means on sub-vectors
1059            let mut centroids: Vec<Vec<f64>> = (0..k)
1060                .map(|_| {
1061                    let idx = rng.random_range(0..n_vectors);
1062                    sub_vecs[idx].clone()
1063                })
1064                .collect();
1065
1066            let mut assignments = vec![0_usize; n_vectors];
1067            for _iter in 0..self.max_iters {
1068                // Assign
1069                let mut changed = false;
1070                for i in 0..n_vectors {
1071                    let mut best = 0;
1072                    let mut best_dist = f64::MAX;
1073                    for (c, centroid) in centroids.iter().enumerate() {
1074                        let d: f64 = sub_vecs[i]
1075                            .iter()
1076                            .zip(centroid.iter())
1077                            .map(|(a, b)| (a - b) * (a - b))
1078                            .sum();
1079                        if d < best_dist {
1080                            best_dist = d;
1081                            best = c;
1082                        }
1083                    }
1084                    if assignments[i] != best {
1085                        changed = true;
1086                        assignments[i] = best;
1087                    }
1088                }
1089                if !changed {
1090                    break;
1091                }
1092                // Update
1093                let mut sums = vec![vec![0.0_f64; sub_dim]; k];
1094                let mut counts = vec![0_usize; k];
1095                for (i, &a) in assignments.iter().enumerate() {
1096                    for d in 0..sub_dim {
1097                        sums[a][d] += sub_vecs[i][d];
1098                    }
1099                    counts[a] += 1;
1100                }
1101                for c in 0..k {
1102                    if counts[c] > 0 {
1103                        for d in 0..sub_dim {
1104                            centroids[c][d] = sums[c][d] / counts[c] as f64;
1105                        }
1106                    }
1107                }
1108            }
1109
1110            codebooks.push(centroids);
1111            for (i, &a) in assignments.iter().enumerate() {
1112                codes[i][m] = a;
1113            }
1114        }
1115
1116        Ok(PqCodes {
1117            codes,
1118            codebooks,
1119            n_subquantizers: m_count,
1120            sub_dim,
1121        })
1122    }
1123
1124    /// Asymmetric Distance Computation (ADC): compute L2 distance from `query` to every
1125    /// encoded vector.  Returns distances of length n_vectors.
1126    pub fn search_adc(&self, query: &[f64], pq_codes: &PqCodes) -> Result<Vec<f64>> {
1127        let dim = pq_codes.n_subquantizers * pq_codes.sub_dim;
1128        if query.len() != dim {
1129            return Err(TensorError::compute_error_simple(format!(
1130                "PQ search: query dim {} != expected {}",
1131                query.len(),
1132                dim,
1133            )));
1134        }
1135
1136        let m_count = pq_codes.n_subquantizers;
1137        let sub_dim = pq_codes.sub_dim;
1138
1139        // Precompute distance tables: dist_table[m][c] = ||q_m - centroid_{m,c}||^2
1140        let mut dist_table: Vec<Vec<f64>> = Vec::with_capacity(m_count);
1141        for m in 0..m_count {
1142            let offset = m * sub_dim;
1143            let q_sub: Vec<f64> = (0..sub_dim).map(|d| query[offset + d]).collect();
1144            let table: Vec<f64> = pq_codes.codebooks[m]
1145                .iter()
1146                .map(|c| {
1147                    q_sub
1148                        .iter()
1149                        .zip(c.iter())
1150                        .map(|(a, b)| (a - b) * (a - b))
1151                        .sum()
1152                })
1153                .collect();
1154            dist_table.push(table);
1155        }
1156
1157        // Compute distance for each vector
1158        let n_vectors = pq_codes.codes.len();
1159        let distances: Vec<f64> = (0..n_vectors)
1160            .map(|i| {
1161                (0..m_count)
1162                    .map(|m| {
1163                        let idx = pq_codes.codes[i][m];
1164                        if idx < dist_table[m].len() {
1165                            dist_table[m][idx]
1166                        } else {
1167                            0.0
1168                        }
1169                    })
1170                    .sum()
1171            })
1172            .collect();
1173
1174        Ok(distances)
1175    }
1176}
1177
1178// ============================================================================
1179// 6. HardwareProfile & HardwareAwareSearch
1180// ============================================================================
1181
/// Description of target hardware capabilities.
#[derive(Debug, Clone)]
pub struct HardwareProfile {
    /// Name of the device (e.g., "Cortex-M7", "Jetson Nano").
    pub name: String,
    /// Available RAM in bytes.
    pub memory_budget_bytes: usize,
    /// Compute budget in MFLOPS.
    pub compute_budget_mflops: f64,
    /// Target latency in milliseconds.
    pub latency_target_ms: f64,
    /// Supported integer bit-widths (e.g., [8, 16, 32]).
    pub supported_int_bits: Vec<usize>,
}

impl HardwareProfile {
    /// Build a profile from raw limits; integer support defaults to 8/16/32-bit.
    pub fn new(name: &str, memory_bytes: usize, mflops: f64, latency_ms: f64) -> Self {
        Self {
            name: name.to_string(),
            memory_budget_bytes: memory_bytes,
            compute_budget_mflops: mflops,
            latency_target_ms: latency_ms,
            supported_int_bits: vec![8, 16, 32],
        }
    }

    /// Predefined profile for a Cortex-M4 class MCU (256 KiB, 100 MFLOPS, 50 ms).
    pub fn cortex_m4() -> Self {
        Self::new("Cortex-M4", 256 * 1024, 100.0, 50.0)
    }

    /// Predefined profile for a mobile phone (mid-range): 2 GiB, 50 000 MFLOPS, 20 ms.
    pub fn mobile_midrange() -> Self {
        Self::new("Mobile-MidRange", Self::gib_to_bytes(2), 50_000.0, 20.0)
    }

    /// Predefined profile for an edge GPU (Jetson Nano class): 4 GiB, 472 000 MFLOPS, 10 ms.
    ///
    /// On targets whose `usize` cannot represent 4 GiB, the memory budget saturates at
    /// `usize::MAX`.
    pub fn jetson_nano() -> Self {
        Self::new("Jetson-Nano", Self::gib_to_bytes(4), 472_000.0, 10.0)
    }

    /// `gib` GiB expressed in bytes, saturating at `usize::MAX`.
    ///
    /// A literal such as `4 * 1024 * 1024 * 1024` typed as `usize` overflows on
    /// 32-bit targets and is rejected by rustc's deny-by-default `arithmetic_overflow`
    /// lint, so the product is computed in `u64` and narrowed explicitly.
    fn gib_to_bytes(gib: u64) -> usize {
        const GIB: u64 = 1 << 30;
        match gib.checked_mul(GIB) {
            Some(bytes) if bytes <= usize::MAX as u64 => bytes as usize,
            _ => usize::MAX,
        }
    }
}
1223
/// One point in the NAS search space together with its estimated costs.
#[derive(Debug, Clone)]
pub struct EoArchCandidate {
    /// Channel-width multiplier relative to the base model (e.g. 0.25 … 1.0).
    pub width_mult: f64,
    /// Depth multiplier relative to the base model.
    pub depth_mult: f64,
    /// Estimated latency in milliseconds (FLOPs-based proxy).
    pub estimated_latency_ms: f64,
    /// Estimated memory footprint in bytes.
    pub estimated_memory_bytes: usize,
    /// Estimated compute in millions of FLOPs.
    pub estimated_mflops: f64,
    /// Accuracy proxy (e.g. from a look-up table or a zero-cost proxy).
    pub accuracy_proxy: f64,
}
1240
1241/// Hardware-Aware Neural Architecture Search.
1242///
1243/// Generates candidate architectures with different width/depth multipliers,
1244/// estimates their latency and memory, and extracts a Pareto frontier
1245/// (accuracy vs. latency).
1246#[derive(Debug, Clone)]
1247pub struct HardwareAwareSearch {
1248    pub profile: HardwareProfile,
1249    /// Base model FLOPs at width=1.0, depth=1.0 (millions).
1250    pub base_mflops: f64,
1251    /// Base model parameters at width=1.0.
1252    pub base_params: usize,
1253    /// Width multipliers to try.
1254    pub width_mults: Vec<f64>,
1255    /// Depth multipliers to try.
1256    pub depth_mults: Vec<f64>,
1257    pub seed: u64,
1258}
1259
1260impl HardwareAwareSearch {
1261    pub fn new(profile: HardwareProfile, base_mflops: f64, base_params: usize) -> Self {
1262        Self {
1263            profile,
1264            base_mflops,
1265            base_params,
1266            width_mults: vec![0.25, 0.5, 0.75, 1.0],
1267            depth_mults: vec![0.5, 0.75, 1.0],
1268            seed: 42,
1269        }
1270    }
1271
1272    /// Generate all candidate architectures.
1273    pub fn generate_candidates(&self) -> Vec<EoArchCandidate> {
1274        let mut rng = StdRng::seed_from_u64(self.seed);
1275        let mut candidates = Vec::new();
1276
1277        for &wm in &self.width_mults {
1278            for &dm in &self.depth_mults {
1279                // FLOPs scale approximately as width^2 * depth (for conv-heavy models)
1280                let mflops = self.base_mflops * wm * wm * dm;
1281                // Memory scales as width * depth * param_size
1282                let mem = (self.base_params as f64 * wm * dm * 4.0) as usize; // 4 bytes/param (f32)
1283                                                                              // Latency proxy: FLOPs / compute_budget
1284                let latency = if self.profile.compute_budget_mflops > 0.0 {
1285                    mflops / self.profile.compute_budget_mflops * 1000.0 // ms
1286                } else {
1287                    f64::MAX
1288                };
1289                // Accuracy proxy: simple power-law heuristic + noise
1290                let accuracy_proxy =
1291                    (0.5 + 0.4 * (wm * dm).powf(0.3) + rng.random_range(-0.02..0.02))
1292                        .clamp(0.0, 1.0);
1293
1294                candidates.push(EoArchCandidate {
1295                    width_mult: wm,
1296                    depth_mult: dm,
1297                    estimated_latency_ms: latency,
1298                    estimated_memory_bytes: mem,
1299                    estimated_mflops: mflops,
1300                    accuracy_proxy,
1301                });
1302            }
1303        }
1304        candidates
1305    }
1306
1307    /// Filter candidates that satisfy hardware constraints.
1308    pub fn filter_feasible(&self, candidates: &[EoArchCandidate]) -> Vec<EoArchCandidate> {
1309        candidates
1310            .iter()
1311            .filter(|c| {
1312                c.estimated_latency_ms <= self.profile.latency_target_ms
1313                    && c.estimated_memory_bytes <= self.profile.memory_budget_bytes
1314            })
1315            .cloned()
1316            .collect()
1317    }
1318
1319    /// Extract Pareto frontier: candidates not dominated on (accuracy, latency).
1320    /// Returns indices into the candidates slice.
1321    pub fn pareto_frontier(&self, candidates: &[EoArchCandidate]) -> Vec<usize> {
1322        let n = candidates.len();
1323        let mut is_dominated = vec![false; n];
1324
1325        for i in 0..n {
1326            if is_dominated[i] {
1327                continue;
1328            }
1329            for j in 0..n {
1330                if i == j || is_dominated[j] {
1331                    continue;
1332                }
1333                // j dominates i if j has >= accuracy AND <= latency (and strictly better in one)
1334                let j_better_acc = candidates[j].accuracy_proxy >= candidates[i].accuracy_proxy;
1335                let j_better_lat =
1336                    candidates[j].estimated_latency_ms <= candidates[i].estimated_latency_ms;
1337                let j_strictly_better = candidates[j].accuracy_proxy > candidates[i].accuracy_proxy
1338                    || candidates[j].estimated_latency_ms < candidates[i].estimated_latency_ms;
1339                if j_better_acc && j_better_lat && j_strictly_better {
1340                    is_dominated[i] = true;
1341                    break;
1342                }
1343            }
1344        }
1345
1346        (0..n).filter(|&i| !is_dominated[i]).collect()
1347    }
1348
1349    /// Run full search: generate, filter, extract Pareto set.
1350    pub fn search(&self) -> Vec<EoArchCandidate> {
1351        let all = self.generate_candidates();
1352        let feasible = self.filter_feasible(&all);
1353        let pareto_idxs = self.pareto_frontier(&feasible);
1354        pareto_idxs.iter().map(|&i| feasible[i].clone()).collect()
1355    }
1356}
1357
1358// ============================================================================
1359// 7. DynamicWidthNetwork  (Slimmable network)
1360// ============================================================================
1361
1362/// A single linear layer that supports dynamic width selection.
1363#[derive(Debug, Clone)]
1364pub struct EoSlimmableLinear {
1365    /// Full weight matrix (out_features x in_features), row-major.
1366    pub weight: Vec<f64>,
1367    /// Full bias vector (out_features).
1368    pub bias: Vec<f64>,
1369    pub in_features: usize,
1370    pub out_features: usize,
1371}
1372
1373impl EoSlimmableLinear {
1374    pub fn new(in_features: usize, out_features: usize, seed: u64) -> Self {
1375        let mut rng = StdRng::seed_from_u64(seed);
1376        // Xavier uniform initialization
1377        let limit = (6.0 / (in_features + out_features) as f64).sqrt();
1378        let weight: Vec<f64> = (0..out_features * in_features)
1379            .map(|_| rng.random_range(-limit..limit))
1380            .collect();
1381        let bias = vec![0.0_f64; out_features];
1382        Self {
1383            weight,
1384            bias,
1385            in_features,
1386            out_features,
1387        }
1388    }
1389
1390    /// Forward pass using only the first `ceil(out * width_mult)` output channels
1391    /// and first `ceil(in * width_mult)` input channels.
1392    pub fn forward_at_width(&self, input: &[f64], width_mult: f64) -> Result<Vec<f64>> {
1393        let active_in = ((self.in_features as f64 * width_mult).ceil() as usize)
1394            .max(1)
1395            .min(self.in_features);
1396        let active_out = ((self.out_features as f64 * width_mult).ceil() as usize)
1397            .max(1)
1398            .min(self.out_features);
1399
1400        if input.len() < active_in {
1401            return Err(TensorError::compute_error_simple(format!(
1402                "SlimmableLinear: input len {} < active_in {}",
1403                input.len(),
1404                active_in,
1405            )));
1406        }
1407
1408        let mut output = Vec::with_capacity(active_out);
1409        for o in 0..active_out {
1410            let mut val = self.bias[o];
1411            for i in 0..active_in {
1412                val += self.weight[o * self.in_features + i] * input[i];
1413            }
1414            output.push(val);
1415        }
1416        Ok(output)
1417    }
1418}
1419
1420/// A slimmable MLP that supports runtime width adaptation.
1421///
1422/// Implements the approach from "Slimmable Neural Networks" (Yu et al., 2019):
1423/// a single network trained at multiple widths (0.25x, 0.5x, 0.75x, 1.0x)
1424/// and selectable at inference time.
1425#[derive(Debug, Clone)]
1426pub struct DynamicWidthNetwork {
1427    pub layers: Vec<EoSlimmableLinear>,
1428    /// Supported width multipliers.
1429    pub width_options: Vec<f64>,
1430}
1431
1432impl DynamicWidthNetwork {
1433    /// Build a slimmable MLP with given layer sizes (at full width).
1434    pub fn new(layer_sizes: &[usize], seed: u64) -> Result<Self> {
1435        if layer_sizes.len() < 2 {
1436            return Err(TensorError::compute_error_simple(
1437                "DynamicWidthNetwork: need at least 2 layer sizes".to_string(),
1438            ));
1439        }
1440        let mut layers = Vec::with_capacity(layer_sizes.len() - 1);
1441        for i in 0..(layer_sizes.len() - 1) {
1442            layers.push(EoSlimmableLinear::new(
1443                layer_sizes[i],
1444                layer_sizes[i + 1],
1445                seed + i as u64,
1446            ));
1447        }
1448        Ok(Self {
1449            layers,
1450            width_options: vec![0.25, 0.5, 0.75, 1.0],
1451        })
1452    }
1453
1454    /// Forward pass at a given width multiplier with ReLU activations.
1455    pub fn forward_at_width(&self, input: &[f64], width_mult: f64) -> Result<Vec<f64>> {
1456        let mut x = input.to_vec();
1457        for (idx, layer) in self.layers.iter().enumerate() {
1458            x = layer.forward_at_width(&x, width_mult)?;
1459            // Apply ReLU to all but last layer
1460            if idx < self.layers.len() - 1 {
1461                for v in &mut x {
1462                    if *v < 0.0 {
1463                        *v = 0.0;
1464                    }
1465                }
1466            }
1467        }
1468        Ok(x)
1469    }
1470
1471    /// Inplace distillation loss: KL divergence between widest and given width.
1472    /// Returns sum of squared differences (simplified distillation loss).
1473    pub fn inplace_distillation_loss(&self, input: &[f64], width_mult: f64) -> Result<f64> {
1474        let teacher_out = self.forward_at_width(input, 1.0)?;
1475        let student_out = self.forward_at_width(input, width_mult)?;
1476        let n = teacher_out.len().min(student_out.len());
1477        let loss: f64 = (0..n)
1478            .map(|i| (teacher_out[i] - student_out[i]).powi(2))
1479            .sum::<f64>()
1480            / n.max(1) as f64;
1481        Ok(loss)
1482    }
1483}
1484
1485// ============================================================================
1486// 8. IntegerArithmetic  (Fixed-point Q-format inference)
1487// ============================================================================
1488
/// Fixed-point number with configurable fractional bits.
/// Representation: value = raw / 2^frac_bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FixedPoint {
    /// Raw integer; the represented value is `raw / 2^frac_bits`.
    pub raw: i32,
    /// Number of fractional bits.
    pub frac_bits: u8,
}

impl FixedPoint {
    /// Create from a floating-point value, rounding to nearest and saturating to the i32 range.
    pub fn from_f64(value: f64, frac_bits: u8) -> Self {
        let raw = (value * Self::scale(frac_bits))
            .round()
            .clamp(i32::MIN as f64, i32::MAX as f64) as i32;
        Self { raw, frac_bits }
    }

    /// Convert back to f64.
    pub fn to_f64(self) -> f64 {
        self.raw as f64 / Self::scale(self.frac_bits)
    }

    /// Fixed-point multiply. The result uses `self.frac_bits`.
    ///
    /// `other` may use a different format. The 64-bit product carries
    /// `self.frac_bits + other.frac_bits` fractional bits, so it is shifted right by
    /// `other.frac_bits` (arithmetic shift, rounding toward -inf) and saturated to
    /// i32. For equal formats this is the usual `(a * b) >> frac_bits`.
    // Inherent method deliberately named like `Mul::mul`; kept for API compatibility.
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, other: Self) -> Self {
        let product = i64::from(self.raw) * i64::from(other.raw);
        // |product| <= 2^62, so a shift of 63 already yields 0 or -1; capping the
        // amount avoids the shift-overflow panic for frac_bits >= 64.
        let shifted = product >> u32::from(other.frac_bits).min(63);
        Self {
            raw: Self::saturate(shifted),
            frac_bits: self.frac_bits,
        }
    }

    /// Saturating fixed-point add. The result uses `self.frac_bits`.
    ///
    /// `other` is first re-expressed with `self.frac_bits` fractional bits.
    /// Surplus fractional bits are truncated toward -inf, and out-of-range values
    /// saturate.
    // Inherent method deliberately named like `Add::add`; kept for API compatibility.
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, other: Self) -> Self {
        let aligned = Self::rescale(other.raw, other.frac_bits, self.frac_bits);
        Self {
            raw: Self::saturate(i64::from(self.raw).saturating_add(aligned)),
            frac_bits: self.frac_bits,
        }
    }

    /// 2^frac_bits as f64. It is computed in floating point because
    /// `1_i64 << frac_bits` flips sign at 63 and overflows from 64.
    fn scale(frac_bits: u8) -> f64 {
        2.0_f64.powi(i32::from(frac_bits))
    }

    /// Re-express `raw` (with `from` fractional bits) using `to` fractional bits, widened to i64.
    fn rescale(raw: i32, from: u8, to: u8) -> i64 {
        let raw = i64::from(raw);
        if to >= from {
            // A left shift of 32 already pushes any non-zero i32 outside the i32 range
            // (saturated by the caller), and still fits in i64.
            raw << u32::from(to - from).min(32)
        } else {
            raw >> u32::from(from - to).min(63)
        }
    }

    /// Clamp an i64 into the i32 range.
    fn saturate(value: i64) -> i32 {
        value.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
    }
}
1532
1533/// Integer-only linear layer: weights in i8, bias in i32, output scale in f64.
1534///
1535/// Performs: output = (W_int8 * input_int8 + bias_int32) * output_scale
1536#[derive(Debug, Clone)]
1537pub struct IntegerLinear {
1538    /// Quantized weights: shape (out_features x in_features), row-major.
1539    pub weight_i8: Vec<i8>,
1540    /// Quantized bias: shape (out_features).
1541    pub bias_i32: Vec<i32>,
1542    /// Scale factors: output = accumulator * output_scale.
1543    pub input_scale: f64,
1544    pub weight_scale: f64,
1545    pub output_scale: f64,
1546    pub in_features: usize,
1547    pub out_features: usize,
1548}
1549
1550impl IntegerLinear {
1551    /// Create from float weights and bias.
1552    /// Quantizes weights to i8 and bias to i32 with symmetric quantization.
1553    pub fn from_float(
1554        weights: &[f64],
1555        bias: &[f64],
1556        in_features: usize,
1557        out_features: usize,
1558    ) -> Result<Self> {
1559        if weights.len() != out_features * in_features {
1560            return Err(TensorError::compute_error_simple(format!(
1561                "IntegerLinear: weight size {} != {}x{}",
1562                weights.len(),
1563                out_features,
1564                in_features,
1565            )));
1566        }
1567        if bias.len() != out_features {
1568            return Err(TensorError::compute_error_simple(format!(
1569                "IntegerLinear: bias size {} != {}",
1570                bias.len(),
1571                out_features,
1572            )));
1573        }
1574
1575        // Symmetric quantization for weights
1576        let w_max = weights.iter().map(|w| w.abs()).fold(0.0_f64, f64::max);
1577        let weight_scale = if w_max > 1e-30 { w_max / 127.0 } else { 1.0 };
1578        let weight_i8: Vec<i8> = weights
1579            .iter()
1580            .map(|&w| (w / weight_scale).round().clamp(-128.0, 127.0) as i8)
1581            .collect();
1582
1583        // For input quantization, assume input is in [-1, 1] by default
1584        let input_scale = 1.0 / 127.0;
1585
1586        // Bias quantization to i32 with combined scale
1587        let bias_scale = input_scale * weight_scale;
1588        let bias_i32: Vec<i32> = bias
1589            .iter()
1590            .map(|&b| {
1591                if bias_scale > 1e-30 {
1592                    (b / bias_scale)
1593                        .round()
1594                        .clamp(i32::MIN as f64, i32::MAX as f64) as i32
1595                } else {
1596                    0
1597                }
1598            })
1599            .collect();
1600
1601        let output_scale = input_scale * weight_scale;
1602
1603        Ok(Self {
1604            weight_i8,
1605            bias_i32,
1606            input_scale,
1607            weight_scale,
1608            output_scale,
1609            in_features,
1610            out_features,
1611        })
1612    }
1613
1614    /// Integer-only forward pass.  `input_i8` is pre-quantized to i8.
1615    pub fn forward_int(&self, input_i8: &[i8]) -> Result<Vec<i32>> {
1616        if input_i8.len() < self.in_features {
1617            return Err(TensorError::compute_error_simple(format!(
1618                "IntegerLinear forward: input len {} < in_features {}",
1619                input_i8.len(),
1620                self.in_features,
1621            )));
1622        }
1623        let mut output = Vec::with_capacity(self.out_features);
1624        for o in 0..self.out_features {
1625            let mut acc: i32 = self.bias_i32[o];
1626            for i in 0..self.in_features {
1627                acc = acc.saturating_add(
1628                    (self.weight_i8[o * self.in_features + i] as i32) * (input_i8[i] as i32),
1629                );
1630            }
1631            output.push(acc);
1632        }
1633        Ok(output)
1634    }
1635
1636    /// Forward pass from float input: quantize, compute, dequantize.
1637    pub fn forward_float(&self, input: &[f64]) -> Result<Vec<f64>> {
1638        let input_i8: Vec<i8> = input
1639            .iter()
1640            .map(|&x| (x / self.input_scale).round().clamp(-128.0, 127.0) as i8)
1641            .collect();
1642        let acc = self.forward_int(&input_i8)?;
1643        Ok(acc.iter().map(|&a| a as f64 * self.output_scale).collect())
1644    }
1645
1646    /// Quantized ReLU: clamp accumulator to [0, max].
1647    pub fn quantized_relu(acc: &[i32]) -> Vec<i32> {
1648        acc.iter().map(|&a| a.max(0)).collect()
1649    }
1650
1651    /// Piecewise-linear sigmoid approximation for integer accumulators.
1652    /// Maps i32 accumulator (with implied scale) to [0, 1] range output as i32 with 8 fractional bits.
1653    pub fn quantized_sigmoid_approx(acc: &[i32], scale: f64) -> Vec<i32> {
1654        // Piecewise linear: 0 for x < -4, 1 for x > 4, linear in between
1655        let frac_bits = 8;
1656        let one = 1 << frac_bits; // 256 = fixed-point 1.0
1657        acc.iter()
1658            .map(|&a| {
1659                let x = a as f64 * scale;
1660                let y = if x < -4.0 {
1661                    0.0
1662                } else if x > 4.0 {
1663                    1.0
1664                } else {
1665                    // Simple piecewise: y = 0.125 * x + 0.5
1666                    (0.125 * x + 0.5).clamp(0.0, 1.0)
1667                };
1668                (y * one as f64).round() as i32
1669            })
1670            .collect()
1671    }
1672}
1673
1674// ============================================================================
1675// 9. MemoryBudgetAllocator
1676// ============================================================================
1677
/// Layer kinds understood by the memory planner.
#[derive(Debug, Clone, PartialEq)]
pub enum EoLayerType {
    /// Convolution layer.
    Conv {
        /// Input channels.
        in_ch: usize,
        /// Output channels.
        out_ch: usize,
        /// Kernel size.
        kernel: usize,
    },
    /// Fully-connected layer.
    Linear {
        /// Input features.
        in_feat: usize,
        /// Output features.
        out_feat: usize,
    },
    /// Batch normalisation.
    BatchNorm {
        /// Normalised channels.
        channels: usize,
    },
    /// ReLU activation.
    Relu,
    /// Pooling layer.
    Pool {
        /// Pooling factor.
        factor: usize,
    },
    /// User-defined operation with an explicit memory cost.
    Custom {
        /// Operation name.
        name: String,
        /// Memory cost in bytes.
        memory_bytes: usize,
    },
}
1702
1703/// Description of a model layer for memory planning.
1704#[derive(Debug, Clone)]
1705pub struct EoLayerDesc {
1706    pub name: String,
1707    pub layer_type: EoLayerType,
1708    /// Spatial resolution at this layer (H, W) or (seq_len, 1) for 1D.
1709    pub spatial: (usize, usize),
1710    /// Batch size.
1711    pub batch_size: usize,
1712}
1713
/// A run of adjacent layers that the allocator proposes to execute as one fused operator.
#[derive(Debug, Clone)]
pub struct EoFusionOp {
    /// Positions (into the planned layer list) of the fused layers, in order.
    pub layer_indices: Vec<usize>,
    /// Human-readable summary of the fusion.
    pub description: String,
    /// Activation bytes that no longer need to be stored once fused.
    pub memory_saved: usize,
}
1724
1725/// Memory allocation plan for a model under a budget.
1726#[derive(Debug, Clone)]
1727pub struct EoAllocationPlan {
1728    /// Per-layer activation memory (bytes).
1729    pub layer_memory: Vec<usize>,
1730    /// Which layers to checkpoint (recompute instead of storing activations).
1731    pub checkpoint_layers: Vec<usize>,
1732    /// Detected fusion opportunities.
1733    pub fusions: Vec<EoFusionOp>,
1734    /// Peak memory estimate (bytes).
1735    pub peak_memory_bytes: usize,
1736    /// Whether the plan fits within budget.
1737    pub fits_budget: bool,
1738}
1739
1740/// Plans memory allocation under a fixed budget.
1741///
1742/// Determines activation checkpointing schedule, detects operator fusion
1743/// opportunities (conv+bn+relu), and estimates peak memory.
1744#[derive(Debug, Clone)]
1745pub struct MemoryBudgetAllocator {
1746    pub budget_bytes: usize,
1747}
1748
1749impl MemoryBudgetAllocator {
1750    pub fn new(budget_bytes: usize) -> Self {
1751        Self { budget_bytes }
1752    }
1753
1754    /// Estimate activation memory for a single layer (in bytes, assuming f32).
1755    pub fn estimate_layer_memory(layer: &EoLayerDesc) -> usize {
1756        let (h, w) = layer.spatial;
1757        let batch = layer.batch_size;
1758        match &layer.layer_type {
1759            EoLayerType::Conv { out_ch, .. } => batch * (*out_ch) * h * w * 4,
1760            EoLayerType::Linear { out_feat, .. } => batch * (*out_feat) * 4,
1761            EoLayerType::BatchNorm { channels } => batch * (*channels) * h * w * 4,
1762            EoLayerType::Relu => batch * h * w * 4, // same size as input, simplified
1763            EoLayerType::Pool { factor } => {
1764                let ph = (h + factor - 1) / factor;
1765                let pw = (w + factor - 1) / factor;
1766                batch * ph * pw * 4
1767            }
1768            EoLayerType::Custom { memory_bytes, .. } => *memory_bytes,
1769        }
1770    }
1771
1772    /// Detect Conv + BatchNorm + ReLU fusion patterns.
1773    pub fn detect_fusions(layers: &[EoLayerDesc]) -> Vec<EoFusionOp> {
1774        let mut fusions = Vec::new();
1775        let n = layers.len();
1776        let mut i = 0;
1777        while i + 2 < n {
1778            let is_conv = matches!(&layers[i].layer_type, EoLayerType::Conv { .. });
1779            let is_bn = matches!(&layers[i + 1].layer_type, EoLayerType::BatchNorm { .. });
1780            let is_relu = matches!(&layers[i + 2].layer_type, EoLayerType::Relu);
1781            if is_conv && is_bn && is_relu {
1782                let bn_mem = Self::estimate_layer_memory(&layers[i + 1]);
1783                let relu_mem = Self::estimate_layer_memory(&layers[i + 2]);
1784                fusions.push(EoFusionOp {
1785                    layer_indices: vec![i, i + 1, i + 2],
1786                    description: format!(
1787                        "Fuse {}/{}/{} -> single Conv+BN+ReLU",
1788                        layers[i].name,
1789                        layers[i + 1].name,
1790                        layers[i + 2].name
1791                    ),
1792                    memory_saved: bn_mem + relu_mem,
1793                });
1794                i += 3;
1795            } else {
1796                i += 1;
1797            }
1798        }
1799        fusions
1800    }
1801
1802    /// Plan memory allocation for a model.
1803    pub fn plan_memory(&self, layers: &[EoLayerDesc]) -> EoAllocationPlan {
1804        let layer_memory: Vec<usize> = layers
1805            .iter()
1806            .map(Self::estimate_layer_memory)
1807            .collect();
1808        let fusions = Self::detect_fusions(layers);
1809
1810        // Total memory without optimization
1811        let total_no_opt: usize = layer_memory.iter().sum();
1812
1813        // Subtract fusion savings
1814        let fusion_savings: usize = fusions.iter().map(|f| f.memory_saved).sum();
1815        let total_after_fusion = total_no_opt.saturating_sub(fusion_savings);
1816
1817        // If still over budget, select layers to checkpoint (largest first, skip first/last)
1818        let mut checkpoint_layers = Vec::new();
1819        let mut current_mem = total_after_fusion;
1820
1821        if current_mem > self.budget_bytes && layers.len() > 2 {
1822            // Sort layer indices by memory (descending), skip first and last
1823            let mut sorted: Vec<(usize, usize)> = layer_memory
1824                .iter()
1825                .enumerate()
1826                .filter(|&(i, _)| i > 0 && i < layers.len() - 1)
1827                .map(|(i, &m)| (i, m))
1828                .collect();
1829            sorted.sort_by_key(|a| std::cmp::Reverse(a.1));
1830
1831            for (idx, mem) in sorted {
1832                if current_mem <= self.budget_bytes {
1833                    break;
1834                }
1835                // Checkpointing saves ~50% of that layer's activation memory
1836                // (we still need to store enough to recompute)
1837                let savings = mem / 2;
1838                current_mem = current_mem.saturating_sub(savings);
1839                checkpoint_layers.push(idx);
1840            }
1841            checkpoint_layers.sort();
1842        }
1843
1844        let peak_memory_bytes = current_mem;
1845        let fits_budget = peak_memory_bytes <= self.budget_bytes;
1846
1847        EoAllocationPlan {
1848            layer_memory,
1849            checkpoint_layers,
1850            fusions,
1851            peak_memory_bytes,
1852            fits_budget,
1853        }
1854    }
1855}
1856
1857// ============================================================================
1858// 10. EdgeMetrics & EdgeReport
1859// ============================================================================
1860
/// Efficiency metrics describing a compressed model for edge/mobile deployment.
#[derive(Debug, Clone)]
pub struct EdgeMetrics {
    /// Compression ratio (original / compressed parameter count).
    pub compression_ratio: f64,
    /// Speedup factor vs baseline (original / compressed FLOPs).
    pub speedup_factor: f64,
    /// Memory footprint in bytes.
    pub memory_footprint_bytes: usize,
    /// Model size in bytes (parameters only).
    pub model_size_bytes: usize,
    /// Number of operations (FLOPs) of the compressed model.
    pub flops: f64,
    /// Accuracy (or proxy).
    pub accuracy: f64,
}

impl EdgeMetrics {
    /// Derive metrics by comparing an original model with its compressed version.
    ///
    /// * `compression_ratio` = `original_params / compressed_params`, or `0.0` when
    ///   `compressed_params` is zero.
    /// * `speedup_factor` = `original_flops / compressed_flops`, or `0.0` unless
    ///   `compressed_flops` is strictly positive (a NaN also yields `0.0`).
    /// * `model_size_bytes` assumes 4 bytes (f32) per compressed parameter.
    /// * `flops` records `compressed_flops`; `memory_bytes` and `accuracy` are stored as given.
    pub fn compute(
        original_params: usize,
        compressed_params: usize,
        original_flops: f64,
        compressed_flops: f64,
        memory_bytes: usize,
        accuracy: f64,
    ) -> Self {
        let compression_ratio = match compressed_params {
            0 => 0.0,
            n => original_params as f64 / n as f64,
        };
        let speedup_factor = (compressed_flops > 0.0)
            .then_some(original_flops / compressed_flops)
            .unwrap_or(0.0);
        let model_size_bytes = compressed_params * 4; // f32 parameters
        Self {
            compression_ratio,
            speedup_factor,
            memory_footprint_bytes: memory_bytes,
            model_size_bytes,
            flops: compressed_flops,
            accuracy,
        }
    }

    /// Harmonic mean of clamped accuracy and normalized compression.
    ///
    /// Accuracy is clamped to `[0, 1]`. The compression ratio is divided by 10 and
    /// clamped to `[0, 1]`, so anything at or beyond 10x earns full credit.
    /// Returns `0.0` when both terms are zero (or their sum is NaN).
    pub fn efficiency_score(&self) -> f64 {
        let acc = self.accuracy.clamp(0.0, 1.0);
        let comp = (self.compression_ratio / 10.0).clamp(0.0, 1.0);
        let total = acc + comp;
        if total > 0.0 {
            2.0 * acc * comp / total
        } else {
            0.0
        }
    }
}
1919
1920/// Comprehensive report for edge deployment analysis.
1921#[derive(Debug, Clone)]
1922pub struct EdgeReport {
1923    pub model_name: String,
1924    pub target_device: String,
1925    pub metrics: EdgeMetrics,
1926    /// Pareto-optimal candidates from HW-aware search (if performed).
1927    pub pareto_candidates: Vec<EoArchCandidate>,
1928    /// Memory allocation plan (if computed).
1929    pub allocation_plan: Option<EoAllocationPlan>,
1930    /// Decomposition errors (if tensor decomposition was applied).
1931    pub decomposition_errors: Vec<f64>,
1932}
1933
1934impl EdgeReport {
1935    pub fn new(model_name: &str, target_device: &str, metrics: EdgeMetrics) -> Self {
1936        Self {
1937            model_name: model_name.to_string(),
1938            target_device: target_device.to_string(),
1939            metrics,
1940            pareto_candidates: Vec::new(),
1941            allocation_plan: None,
1942            decomposition_errors: Vec::new(),
1943        }
1944    }
1945
1946    /// Pretty-print the report.
1947    pub fn summary(&self) -> String {
1948        let mut s = String::new();
1949        s.push_str(&format!("=== Edge Report: {} ===\n", self.model_name));
1950        s.push_str(&format!("Target: {}\n", self.target_device));
1951        s.push_str(&format!(
1952            "Compression: {:.2}x\n",
1953            self.metrics.compression_ratio
1954        ));
1955        s.push_str(&format!("Speedup: {:.2}x\n", self.metrics.speedup_factor));
1956        s.push_str(&format!(
1957            "Memory: {} bytes\n",
1958            self.metrics.memory_footprint_bytes
1959        ));
1960        s.push_str(&format!(
1961            "Model size: {} bytes\n",
1962            self.metrics.model_size_bytes
1963        ));
1964        s.push_str(&format!("FLOPs: {:.0}\n", self.metrics.flops));
1965        s.push_str(&format!("Accuracy: {:.4}\n", self.metrics.accuracy));
1966        s.push_str(&format!(
1967            "Efficiency score: {:.4}\n",
1968            self.metrics.efficiency_score()
1969        ));
1970        if !self.pareto_candidates.is_empty() {
1971            s.push_str(&format!(
1972                "Pareto candidates: {}\n",
1973                self.pareto_candidates.len()
1974            ));
1975        }
1976        if let Some(ref plan) = self.allocation_plan {
1977            s.push_str(&format!(
1978                "Memory plan: peak={} bytes, fits={}\n",
1979                plan.peak_memory_bytes, plan.fits_budget
1980            ));
1981        }
1982        s
1983    }
1984
1985    /// Check if the Pareto frontier dominance analysis found any candidates.
1986    pub fn pareto_analysis_summary(&self) -> Vec<(f64, f64)> {
1987        self.pareto_candidates
1988            .iter()
1989            .map(|c| (c.accuracy_proxy, c.estimated_latency_ms))
1990            .collect()
1991    }
1992}
1993
1994// ============================================================================
1995// Tests
1996// ============================================================================
1997
1998#[cfg(test)]
1999mod tests;