// Source: scirs2_neural/mobile/mod.rs

//! Mobile neural network deployment and optimisation
//!
//! This module provides lightweight building blocks for deploying neural networks
//! on mobile and edge devices:
//!
//! - [`MobileNetConfig`] – configuration for MobileNet-style architectures
//! - [`DepthwiseSeparableConv`] – depthwise + pointwise convolution layer
//! - [`MobileNetV2Block`] – inverted residual block (MobileNetV2 bottleneck)
//! - [`MobileOptimizer`] – quantization / pruning utilities for mobile deployment
//!
//! # References
//! - Howard et al., "MobileNets", 2017 <https://arxiv.org/abs/1704.04861>
//! - Sandler et al., "MobileNetV2", 2018 <https://arxiv.org/abs/1801.04381>
14
15use crate::error::{NeuralError, Result};
16use serde::{Deserialize, Serialize};
17
18// ─────────────────────────────────────────────────────────────────────────────
19// MobileNetConfig
20// ─────────────────────────────────────────────────────────────────────────────
21
22/// Configuration for a MobileNet-style architecture.
23///
24/// # Examples
25/// ```
26/// use scirs2_neural::mobile::MobileNetConfig;
27///
28/// let cfg = MobileNetConfig::mobilenet_v1();
29/// assert_eq!(cfg.input_resolution, 224);
30/// ```
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct MobileNetConfig {
33    /// Width multiplier α (scales number of channels, typical 0.25–1.0)
34    pub width_multiplier: f64,
35    /// Input image resolution (square, typical 128/160/192/224)
36    pub input_resolution: usize,
37    /// Number of output classes
38    pub num_classes: usize,
39    /// MobileNet version
40    pub version: MobileNetVersion,
41    /// Dropout rate before the final classifier
42    pub dropout_rate: f64,
43    /// Whether to use batch normalisation
44    pub use_batch_norm: bool,
45}
46
47/// MobileNet architecture version.
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
49pub enum MobileNetVersion {
50    /// MobileNetV1 – depthwise separable convolutions
51    V1,
52    /// MobileNetV2 – inverted residuals + linear bottlenecks
53    V2,
54    /// MobileNetV3 – hard-swish, SE blocks, NAS-searched
55    V3Small,
56    /// MobileNetV3 Large variant
57    V3Large,
58}
59
60impl MobileNetConfig {
61    /// Standard MobileNetV1 configuration (1× width, 224×224 input).
62    pub fn mobilenet_v1() -> Self {
63        Self {
64            width_multiplier: 1.0,
65            input_resolution: 224,
66            num_classes: 1000,
67            version: MobileNetVersion::V1,
68            dropout_rate: 0.001,
69            use_batch_norm: true,
70        }
71    }
72
73    /// Standard MobileNetV2 configuration (1× width, 224×224 input).
74    pub fn mobilenet_v2() -> Self {
75        Self {
76            width_multiplier: 1.0,
77            input_resolution: 224,
78            num_classes: 1000,
79            version: MobileNetVersion::V2,
80            dropout_rate: 0.2,
81            use_batch_norm: true,
82        }
83    }
84
85    /// Lightweight configuration suitable for low-power devices (0.25× width, 128×128).
86    pub fn mobile_lite() -> Self {
87        Self {
88            width_multiplier: 0.25,
89            input_resolution: 128,
90            num_classes: 10,
91            version: MobileNetVersion::V2,
92            dropout_rate: 0.0,
93            use_batch_norm: true,
94        }
95    }
96
97    /// Compute the number of channels at a given base channel count after applying the width multiplier.
98    pub fn scaled_channels(&self, base: usize) -> usize {
99        ((base as f64) * self.width_multiplier).round() as usize
100    }
101
102    /// Estimate total parameter count for a simple 4-layer depthwise separable network.
103    pub fn estimated_param_count(&self) -> usize {
104        // Rough estimate: 3×3 depthwise + 1×1 pointwise × 4 stages × channel size
105        let c = self.scaled_channels(32);
106        let dw_params = 3 * 3 * c; // depthwise kernel
107        let pw_params = c * (c * 2); // pointwise expanding
108        (dw_params + pw_params) * 4
109    }
110}
111
112// ─────────────────────────────────────────────────────────────────────────────
113// DepthwiseSeparableConv
114// ─────────────────────────────────────────────────────────────────────────────
115
116/// Depthwise-separable convolution layer (inference-only, f32 weights).
117///
118/// Factorises a standard K×K convolution into:
119/// 1. **Depthwise** 3×3 conv (one filter per input channel)
120/// 2. **Pointwise** 1×1 conv (mixes channels)
121///
122/// Weight layout:
123/// - depthwise weights: `[in_channels, kH, kW]`
124/// - pointwise weights: `[out_channels, in_channels]`
125/// - biases: `[out_channels]`
126///
127/// # Examples
128/// ```
129/// use scirs2_neural::mobile::DepthwiseSeparableConv;
130///
131/// let dsc = DepthwiseSeparableConv::new(8, 16, (3, 3)).expect("dsc");
132/// assert_eq!(dsc.in_channels(), 8);
133/// assert_eq!(dsc.out_channels(), 16);
134/// ```
135#[derive(Debug, Clone)]
136pub struct DepthwiseSeparableConv {
137    in_ch: usize,
138    out_ch: usize,
139    kernel_size: (usize, usize),
140    // depthwise weights [in_ch, kH, kW]
141    dw_weights: Vec<f32>,
142    // pointwise weights [out_ch, in_ch]
143    pw_weights: Vec<f32>,
144    // bias [out_ch]
145    bias: Vec<f32>,
146}
147
148impl DepthwiseSeparableConv {
149    /// Create a new layer with randomly-initialised He weights.
150    pub fn new(
151        in_channels: usize,
152        out_channels: usize,
153        kernel_size: (usize, usize),
154    ) -> Result<Self> {
155        if in_channels == 0 || out_channels == 0 {
156            return Err(NeuralError::InvalidArgument(
157                "DepthwiseSeparableConv: channel counts must be > 0".to_string(),
158            ));
159        }
160        let (kh, kw) = kernel_size;
161        let dw_size = in_channels * kh * kw;
162        let pw_size = out_channels * in_channels;
163
164        // He initialisation scale
165        let dw_scale = (2.0_f32 / (kh * kw) as f32).sqrt();
166        let pw_scale = (2.0_f32 / in_channels as f32).sqrt();
167
168        let dw_weights = pseudo_random_weights(dw_size, dw_scale, 1);
169        let pw_weights = pseudo_random_weights(pw_size, pw_scale, 2);
170        let bias = vec![0.0_f32; out_channels];
171
172        Ok(Self {
173            in_ch: in_channels,
174            out_ch: out_channels,
175            kernel_size,
176            dw_weights,
177            pw_weights,
178            bias,
179        })
180    }
181
182    /// Returns the number of input channels.
183    pub fn in_channels(&self) -> usize {
184        self.in_ch
185    }
186
187    /// Returns the number of output channels.
188    pub fn out_channels(&self) -> usize {
189        self.out_ch
190    }
191
192    /// Returns the kernel size.
193    pub fn kernel_size(&self) -> (usize, usize) {
194        self.kernel_size
195    }
196
197    /// Total trainable parameter count.
198    pub fn parameter_count(&self) -> usize {
199        self.dw_weights.len() + self.pw_weights.len() + self.bias.len()
200    }
201
202    /// Forward pass on a flat `[batch × in_ch × H × W]` f32 slice.
203    ///
204    /// `input_shape` must be `[batch, in_ch, H, W]`.
205    /// Returns a flat `[batch × out_ch × H_out × W_out]` vector.
206    pub fn forward(
207        &self,
208        input: &[f32],
209        input_shape: [usize; 4],
210    ) -> Result<(Vec<f32>, [usize; 4])> {
211        let [batch, in_ch, h, w] = input_shape;
212        if in_ch != self.in_ch {
213            return Err(NeuralError::ShapeMismatch(format!(
214                "DepthwiseSeparableConv: expected in_ch={}, got {}",
215                self.in_ch, in_ch
216            )));
217        }
218        if input.len() != batch * in_ch * h * w {
219            return Err(NeuralError::ShapeMismatch(
220                "DepthwiseSeparableConv: input slice length mismatch".to_string(),
221            ));
222        }
223
224        let (kh, kw) = self.kernel_size;
225        let padding = (kh / 2, kw / 2);
226        let h_out = (h + 2 * padding.0).saturating_sub(kh) + 1;
227        let w_out = (w + 2 * padding.1).saturating_sub(kw) + 1;
228
229        // ── 1. Depthwise conv ──────────────────────────────────────────────
230        let dw_size = batch * in_ch * h_out * w_out;
231        let mut dw_out = vec![0.0_f32; dw_size];
232
233        for b in 0..batch {
234            for c in 0..in_ch {
235                for oh in 0..h_out {
236                    for ow in 0..w_out {
237                        let mut acc = 0.0_f32;
238                        for ki in 0..kh {
239                            for kj in 0..kw {
240                                let ih = oh + ki;
241                                let iw = ow + kj;
242                                // Padding check (implicit zero-padding)
243                                let ih_src = ih.wrapping_sub(padding.0);
244                                let iw_src = iw.wrapping_sub(padding.1);
245                                if ih_src < h && iw_src < w {
246                                    let in_idx =
247                                        b * in_ch * h * w + c * h * w + ih_src * w + iw_src;
248                                    let w_idx = c * kh * kw + ki * kw + kj;
249                                    acc += input[in_idx] * self.dw_weights[w_idx];
250                                }
251                            }
252                        }
253                        // ReLU6
254                        let idx = b * in_ch * h_out * w_out + c * h_out * w_out + oh * w_out + ow;
255                        dw_out[idx] = acc.clamp(0.0, 6.0);
256                    }
257                }
258            }
259        }
260
261        // ── 2. Pointwise conv (1×1) ────────────────────────────────────────
262        let pw_size = batch * self.out_ch * h_out * w_out;
263        let mut pw_out = vec![0.0_f32; pw_size];
264
265        for b in 0..batch {
266            for oc in 0..self.out_ch {
267                for oh in 0..h_out {
268                    for ow in 0..w_out {
269                        let mut acc = self.bias[oc];
270                        for ic in 0..in_ch {
271                            let dw_idx =
272                                b * in_ch * h_out * w_out + ic * h_out * w_out + oh * w_out + ow;
273                            let pw_idx = oc * in_ch + ic;
274                            acc += dw_out[dw_idx] * self.pw_weights[pw_idx];
275                        }
276                        let out_idx =
277                            b * self.out_ch * h_out * w_out + oc * h_out * w_out + oh * w_out + ow;
278                        pw_out[out_idx] = acc.clamp(0.0, 6.0);
279                    }
280                }
281            }
282        }
283
284        Ok((pw_out, [batch, self.out_ch, h_out, w_out]))
285    }
286}
287
288// ─────────────────────────────────────────────────────────────────────────────
289// MobileNetV2Block
290// ─────────────────────────────────────────────────────────────────────────────
291
292/// MobileNetV2 inverted residual block.
293///
294/// Structure (for stride=1 with residual shortcut):
295/// ```text
296/// input  ──[PW expand]──[DW 3×3]──[PW project]──⊕── output
297///        └────────────────────────────────────────┘
298/// ```
299/// When `stride > 1` or `in_channels != out_channels`, no residual is added.
300///
301/// # Examples
302/// ```
303/// use scirs2_neural::mobile::MobileNetV2Block;
304///
305/// let block = MobileNetV2Block::new(32, 16, 6, 1).expect("block");
306/// assert_eq!(block.out_channels(), 16);
307/// ```
308#[derive(Debug, Clone)]
309pub struct MobileNetV2Block {
310    in_ch: usize,
311    out_ch: usize,
312    expansion: usize,
313    stride: usize,
314    /// Expansion pointwise: in_ch → expanded_ch
315    expand_pw: Option<PointwiseConv>,
316    /// Depthwise: expanded_ch × 3×3
317    dw: DepthwiseSeparableConv,
318    /// Projection pointwise: expanded_ch → out_ch (no activation)
319    project_pw: PointwiseConv,
320    /// Whether to add a residual shortcut
321    use_residual: bool,
322}
323
324impl MobileNetV2Block {
325    /// Create a new inverted residual block.
326    ///
327    /// # Arguments
328    /// * `in_channels` – input feature maps
329    /// * `out_channels` – output feature maps
330    /// * `expansion_factor` – channel expansion multiplier (typically 6)
331    /// * `stride` – depthwise convolution stride (1 or 2)
332    pub fn new(
333        in_channels: usize,
334        out_channels: usize,
335        expansion_factor: usize,
336        stride: usize,
337    ) -> Result<Self> {
338        if in_channels == 0 || out_channels == 0 {
339            return Err(NeuralError::InvalidArgument(
340                "MobileNetV2Block: channel counts must be > 0".to_string(),
341            ));
342        }
343        if stride == 0 {
344            return Err(NeuralError::InvalidArgument(
345                "MobileNetV2Block: stride must be >= 1".to_string(),
346            ));
347        }
348
349        let expanded_ch = in_channels * expansion_factor;
350
351        // Expand PW (skip for expansion factor == 1)
352        let expand_pw = if expansion_factor != 1 {
353            Some(PointwiseConv::new(in_channels, expanded_ch)?)
354        } else {
355            None
356        };
357
358        // Depthwise conv on expanded channels
359        let dw = DepthwiseSeparableConv::new(expanded_ch, expanded_ch, (3, 3))?;
360        // Project PW (no activation)
361        let project_pw = PointwiseConv::new(expanded_ch, out_channels)?;
362
363        let use_residual = stride == 1 && in_channels == out_channels;
364
365        Ok(Self {
366            in_ch: in_channels,
367            out_ch: out_channels,
368            expansion: expansion_factor,
369            stride,
370            expand_pw,
371            dw,
372            project_pw,
373            use_residual,
374        })
375    }
376
377    /// Returns the number of input channels.
378    pub fn in_channels(&self) -> usize {
379        self.in_ch
380    }
381
382    /// Returns the number of output channels.
383    pub fn out_channels(&self) -> usize {
384        self.out_ch
385    }
386
387    /// Returns the expansion factor.
388    pub fn expansion(&self) -> usize {
389        self.expansion
390    }
391
392    /// Returns the stride.
393    pub fn stride(&self) -> usize {
394        self.stride
395    }
396
397    /// Returns whether this block uses a residual shortcut.
398    pub fn has_residual(&self) -> bool {
399        self.use_residual
400    }
401
402    /// Total parameter count.
403    pub fn parameter_count(&self) -> usize {
404        let expand = self
405            .expand_pw
406            .as_ref()
407            .map(|p| p.parameter_count())
408            .unwrap_or(0);
409        expand + self.dw.parameter_count() + self.project_pw.parameter_count()
410    }
411
412    /// Forward pass.
413    ///
414    /// Input: flat `[batch × in_ch × H × W]` f32 slice.
415    /// Returns: `(output_flat, [batch, out_ch, H_out, W_out])`.
416    pub fn forward(&self, input: &[f32], shape: [usize; 4]) -> Result<(Vec<f32>, [usize; 4])> {
417        let [batch, in_ch, h, w] = shape;
418        if in_ch != self.in_ch {
419            return Err(NeuralError::ShapeMismatch(format!(
420                "MobileNetV2Block: expected in_ch={}, got {}",
421                self.in_ch, in_ch
422            )));
423        }
424
425        // ── Expand PW ─────────────────────────────────────────────────────
426        let (expanded, expanded_shape) = if let Some(ref pw) = self.expand_pw {
427            pw.forward_with_relu6(input, shape)?
428        } else {
429            (input.to_vec(), shape)
430        };
431
432        // ── Depthwise conv ─────────────────────────────────────────────────
433        // We only use the depthwise part (forward already includes pointwise,
434        // but we manually call the depthwise and skip the pointwise channel-mix)
435        let (dw_out, dw_shape) = depthwise_only(
436            &expanded,
437            expanded_shape,
438            &self.dw.dw_weights,
439            self.dw.kernel_size,
440            self.stride,
441        )?;
442
443        // ── Project PW ────────────────────────────────────────────────────
444        let (projected, proj_shape) = self
445            .project_pw
446            .forward_linear(dw_out.as_slice(), dw_shape)?;
447
448        // ── Residual shortcut ─────────────────────────────────────────────
449        let output = if self.use_residual {
450            input
451                .iter()
452                .zip(projected.iter())
453                .map(|(a, b)| a + b)
454                .collect()
455        } else {
456            projected
457        };
458
459        Ok((output, proj_shape))
460    }
461}
462
463// ─────────────────────────────────────────────────────────────────────────────
464// PointwiseConv (internal helper)
465// ─────────────────────────────────────────────────────────────────────────────
466
467/// 1×1 pointwise convolution (internal helper).
468#[derive(Debug, Clone)]
469struct PointwiseConv {
470    in_ch: usize,
471    out_ch: usize,
472    weights: Vec<f32>, // [out_ch, in_ch]
473    bias: Vec<f32>,    // [out_ch]
474}
475
476impl PointwiseConv {
477    fn new(in_channels: usize, out_channels: usize) -> Result<Self> {
478        let size = out_channels * in_channels;
479        let scale = (2.0_f32 / in_channels as f32).sqrt();
480        Ok(Self {
481            in_ch: in_channels,
482            out_ch: out_channels,
483            weights: pseudo_random_weights(size, scale, 3),
484            bias: vec![0.0_f32; out_channels],
485        })
486    }
487
488    fn parameter_count(&self) -> usize {
489        self.weights.len() + self.bias.len()
490    }
491
492    /// Forward with ReLU6 activation.
493    fn forward_with_relu6(
494        &self,
495        input: &[f32],
496        shape: [usize; 4],
497    ) -> Result<(Vec<f32>, [usize; 4])> {
498        let [batch, in_ch, h, w] = shape;
499        if in_ch != self.in_ch {
500            return Err(NeuralError::ShapeMismatch(format!(
501                "PointwiseConv: in_ch mismatch {} vs {}",
502                self.in_ch, in_ch
503            )));
504        }
505        let out_size = batch * self.out_ch * h * w;
506        let mut out = vec![0.0_f32; out_size];
507
508        for b in 0..batch {
509            for oc in 0..self.out_ch {
510                for ph in 0..h {
511                    for pw_pos in 0..w {
512                        let mut acc = self.bias[oc];
513                        for ic in 0..in_ch {
514                            let in_idx = b * in_ch * h * w + ic * h * w + ph * w + pw_pos;
515                            acc += input[in_idx] * self.weights[oc * in_ch + ic];
516                        }
517                        let out_idx = b * self.out_ch * h * w + oc * h * w + ph * w + pw_pos;
518                        out[out_idx] = acc.clamp(0.0, 6.0);
519                    }
520                }
521            }
522        }
523        Ok((out, [batch, self.out_ch, h, w]))
524    }
525
526    /// Forward without activation (used for projection PW in V2 block).
527    fn forward_linear(&self, input: &[f32], shape: [usize; 4]) -> Result<(Vec<f32>, [usize; 4])> {
528        let [batch, in_ch, h, w] = shape;
529        if in_ch != self.in_ch {
530            return Err(NeuralError::ShapeMismatch(format!(
531                "PointwiseConv(linear): in_ch mismatch {} vs {}",
532                self.in_ch, in_ch
533            )));
534        }
535        let out_size = batch * self.out_ch * h * w;
536        let mut out = vec![0.0_f32; out_size];
537        for b in 0..batch {
538            for oc in 0..self.out_ch {
539                for ph in 0..h {
540                    for pw_pos in 0..w {
541                        let mut acc = self.bias[oc];
542                        for ic in 0..in_ch {
543                            let in_idx = b * in_ch * h * w + ic * h * w + ph * w + pw_pos;
544                            acc += input[in_idx] * self.weights[oc * in_ch + ic];
545                        }
546                        let out_idx = b * self.out_ch * h * w + oc * h * w + ph * w + pw_pos;
547                        out[out_idx] = acc;
548                    }
549                }
550            }
551        }
552        Ok((out, [batch, self.out_ch, h, w]))
553    }
554}
555
556// ─────────────────────────────────────────────────────────────────────────────
557// Depthwise-only helper
558// ─────────────────────────────────────────────────────────────────────────────
559
560/// Runs just the depthwise convolution part (no channel mixing).
561fn depthwise_only(
562    input: &[f32],
563    shape: [usize; 4],
564    weights: &[f32],
565    kernel_size: (usize, usize),
566    stride: usize,
567) -> Result<(Vec<f32>, [usize; 4])> {
568    let [batch, channels, h, w] = shape;
569    let (kh, kw) = kernel_size;
570    let padding = (kh / 2, kw / 2);
571    let h_out = if stride == 1 {
572        h
573    } else {
574        (h + 2 * padding.0).saturating_sub(kh) / stride + 1
575    };
576    let w_out = if stride == 1 {
577        w
578    } else {
579        (w + 2 * padding.1).saturating_sub(kw) / stride + 1
580    };
581
582    let mut out = vec![0.0_f32; batch * channels * h_out * w_out];
583    for b in 0..batch {
584        for c in 0..channels {
585            for oh in 0..h_out {
586                for ow in 0..w_out {
587                    let mut acc = 0.0_f32;
588                    for ki in 0..kh {
589                        for kj in 0..kw {
590                            let ih = oh * stride + ki;
591                            let iw = ow * stride + kj;
592                            let ih_src = ih.wrapping_sub(padding.0);
593                            let iw_src = iw.wrapping_sub(padding.1);
594                            if ih_src < h && iw_src < w {
595                                let in_idx = b * channels * h * w + c * h * w + ih_src * w + iw_src;
596                                let w_idx = c * kh * kw + ki * kw + kj;
597                                acc += input[in_idx] * weights[w_idx];
598                            }
599                        }
600                    }
601                    let out_idx =
602                        b * channels * h_out * w_out + c * h_out * w_out + oh * w_out + ow;
603                    out[out_idx] = acc.clamp(0.0, 6.0);
604                }
605            }
606        }
607    }
608    Ok((out, [batch, channels, h_out, w_out]))
609}
610
611// ─────────────────────────────────────────────────────────────────────────────
612// MobileOptimizer
613// ─────────────────────────────────────────────────────────────────────────────
614
615/// Mobile-optimised model utilities.
616///
617/// Provides helpers to reduce model size and latency for mobile deployment.
618pub struct MobileOptimizer {
619    /// Target model size budget in kilobytes
620    pub size_budget_kb: f64,
621    /// Minimum acceptable accuracy drop (0.0–1.0)
622    pub max_accuracy_drop: f64,
623}
624
625impl MobileOptimizer {
626    /// Create a new mobile optimizer.
627    pub fn new(size_budget_kb: f64, max_accuracy_drop: f64) -> Result<Self> {
628        if size_budget_kb <= 0.0 {
629            return Err(NeuralError::InvalidArgument(
630                "size_budget_kb must be > 0".to_string(),
631            ));
632        }
633        Ok(Self {
634            size_budget_kb,
635            max_accuracy_drop: max_accuracy_drop.clamp(0.0, 1.0),
636        })
637    }
638
639    /// Estimate the byte size of a weight vector at a given bit-width.
640    pub fn estimate_size_bytes(num_weights: usize, bits_per_weight: u8) -> usize {
641        (num_weights * bits_per_weight as usize).div_ceil(8)
642    }
643
644    /// Quantise f32 weights to INT8 symmetric representation.
645    ///
646    /// Returns `(quantized, scale)`.
647    pub fn quantize_int8(weights: &[f32]) -> Result<(Vec<i8>, f32)> {
648        if weights.is_empty() {
649            return Err(NeuralError::InvalidArgument(
650                "quantize_int8: empty weights".to_string(),
651            ));
652        }
653        let abs_max = weights.iter().fold(0.0_f32, |acc, &v| acc.max(v.abs()));
654        let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
655        let quantized: Vec<i8> = weights
656            .iter()
657            .map(|&w| (w / scale).round().clamp(-128.0, 127.0) as i8)
658            .collect();
659        Ok((quantized, scale))
660    }
661
662    /// Prune weights whose absolute value is below `threshold` (set to 0).
663    pub fn magnitude_prune(weights: &mut [f32], sparsity: f64) {
664        if weights.is_empty() || sparsity <= 0.0 {
665            return;
666        }
667        let n = weights.len();
668        let mut sorted_abs: Vec<f32> = weights.iter().map(|v| v.abs()).collect();
669        sorted_abs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
670        let cutoff_idx = ((sparsity.clamp(0.0, 1.0) * n as f64) as usize).min(n.saturating_sub(1));
671        let threshold = sorted_abs[cutoff_idx];
672        for w in weights.iter_mut() {
673            if w.abs() < threshold {
674                *w = 0.0;
675            }
676        }
677    }
678
679    /// Returns whether a model fits within the configured size budget.
680    pub fn fits_budget(&self, param_count: usize) -> bool {
681        let bytes = Self::estimate_size_bytes(param_count, 32);
682        (bytes as f64 / 1024.0) <= self.size_budget_kb
683    }
684}
685
686// ─────────────────────────────────────────────────────────────────────────────
687// Helpers
688// ─────────────────────────────────────────────────────────────────────────────
689
/// Generate deterministic pseudo-random weights for testing (no external deps).
///
/// Produces `n` values in `[-scale, scale)` from a 64-bit linear congruential
/// generator. The sequence depends only on `seed_offset`, so repeated calls
/// with the same arguments return identical weights.
fn pseudo_random_weights(n: usize, scale: f32, seed_offset: u64) -> Vec<f32> {
    // 2^24: every integer below it is exactly representable in an f32, so the
    // division below lands in [0, 1) without rounding up to 1.0.
    const UNIT_RANGE: f32 = (1u32 << 24) as f32;

    let mut state: u64 = 0xDEAD_BEEF_0000_0001u64.wrapping_add(seed_offset);
    (0..n)
        .map(|_| {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Use the top 24 bits (the best-mixed bits of an LCG) and divide by
            // the matching range. Dividing a 31-bit value by `u32::MAX`, as
            // before, only covered [0, 0.5] and made every weight non-positive.
            let u = (state >> 40) as f32 / UNIT_RANGE; // [0, 1)
            (u * 2.0 - 1.0) * scale
        })
        .collect()
}
703
704// ─────────────────────────────────────────────────────────────────────────────
705// Tests
706// ─────────────────────────────────────────────────────────────────────────────
707
// Unit tests for the mobile building blocks. Shape assertions pin exact
// values so that a regression in padding or stride handling is caught.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mobile_net_config_v1() {
        let cfg = MobileNetConfig::mobilenet_v1();
        assert_eq!(cfg.input_resolution, 224);
        assert_eq!(cfg.version, MobileNetVersion::V1);
        assert!((cfg.width_multiplier - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_mobile_net_config_v2() {
        let cfg = MobileNetConfig::mobilenet_v2();
        assert_eq!(cfg.version, MobileNetVersion::V2);
    }

    #[test]
    fn test_scaled_channels() {
        let cfg = MobileNetConfig {
            width_multiplier: 0.5,
            ..MobileNetConfig::mobilenet_v2()
        };
        assert_eq!(cfg.scaled_channels(32), 16);
        assert_eq!(cfg.scaled_channels(64), 32);
    }

    #[test]
    fn test_depthwise_separable_conv_creation() {
        let dsc = DepthwiseSeparableConv::new(4, 8, (3, 3)).expect("dsc ok");
        assert_eq!(dsc.in_channels(), 4);
        assert_eq!(dsc.out_channels(), 8);
        assert!(dsc.parameter_count() > 0);
    }

    #[test]
    fn test_depthwise_separable_conv_forward() {
        let dsc = DepthwiseSeparableConv::new(2, 4, (3, 3)).expect("dsc ok");
        // batch=1, in_ch=2, H=8, W=8 (NCHW)
        let input = vec![0.5_f32; 2 * 8 * 8];
        let (output, out_shape) = dsc.forward(&input, [1, 2, 8, 8]).expect("forward ok");
        // same-padding keeps the 8×8 spatial size
        assert_eq!(out_shape, [1, 4, 8, 8]);
        assert_eq!(output.len(), 4 * 8 * 8);
    }

    #[test]
    fn test_depthwise_separable_conv_channel_mismatch_err() {
        let dsc = DepthwiseSeparableConv::new(4, 8, (3, 3)).expect("dsc ok");
        let input = vec![0.0_f32; 2 * 4 * 4]; // wrong channels (batch=1 NCHW)
        let result = dsc.forward(&input, [1, 2, 4, 4]);
        assert!(result.is_err());
    }

    #[test]
    fn test_mobilenet_v2_block_creation() {
        let block = MobileNetV2Block::new(32, 16, 6, 1).expect("block ok");
        assert_eq!(block.in_channels(), 32);
        assert_eq!(block.out_channels(), 16);
        assert!(!block.has_residual()); // channels differ
    }

    #[test]
    fn test_mobilenet_v2_block_residual() {
        let block = MobileNetV2Block::new(16, 16, 6, 1).expect("block ok");
        assert!(block.has_residual());
    }

    #[test]
    fn test_mobilenet_v2_block_forward() {
        let block = MobileNetV2Block::new(8, 8, 6, 1).expect("block ok");
        let input = vec![0.1_f32; 8 * 4 * 4]; // batch=1 NCHW
        let (output, out_shape) = block.forward(&input, [1, 8, 4, 4]).expect("fwd ok");
        // stride=1 keeps the 4×4 spatial size
        assert_eq!(out_shape, [1, 8, 4, 4]);
        assert_eq!(output.len(), 8 * 4 * 4);
    }

    #[test]
    fn test_mobilenet_v2_block_stride2() {
        let block = MobileNetV2Block::new(8, 16, 6, 2).expect("block ok");
        assert!(!block.has_residual());
        let input = vec![0.1_f32; 8 * 8 * 8]; // batch=1 NCHW
        let (output, out_shape) = block.forward(&input, [1, 8, 8, 8]).expect("fwd ok");
        // With stride=2, spatial dims halve: 8×8 -> 4×4
        assert_eq!(out_shape, [1, 16, 4, 4]);
        assert_eq!(output.len(), 16 * 4 * 4);
    }

    #[test]
    fn test_mobile_optimizer_quantize_int8() {
        let weights = vec![0.5_f32, -0.5, 1.0, -1.0, 0.0];
        let (q, scale) = MobileOptimizer::quantize_int8(&weights).expect("ok");
        assert_eq!(q.len(), weights.len());
        let dequant: Vec<f32> = q.iter().map(|&v| v as f32 * scale).collect();
        for (orig, deq) in weights.iter().zip(dequant.iter()) {
            assert!((orig - deq).abs() < 0.01, "orig={orig} deq={deq}");
        }
    }

    #[test]
    fn test_mobile_optimizer_prune() {
        let mut weights = vec![0.01_f32, 0.5, 0.001, 1.0, 0.002];
        MobileOptimizer::magnitude_prune(&mut weights, 0.6);
        // Bottom 60% (3 out of 5) should be zeroed, the two largest kept
        let zeros = weights.iter().filter(|&&v| v == 0.0).count();
        assert_eq!(zeros, 3, "expected 3 zeros, got {zeros}");
        assert_eq!(weights[1], 0.5);
        assert_eq!(weights[3], 1.0);
    }

    #[test]
    fn test_mobile_optimizer_budget() {
        let opt = MobileOptimizer::new(1000.0, 0.01).expect("ok");
        // 10 weights at FP32 = 40 bytes ≈ tiny, fits
        assert!(opt.fits_budget(10));
        // 10M weights at FP32 ≈ 40MB, won't fit in 1000 KB
        assert!(!opt.fits_budget(10_000_000));
    }

    #[test]
    fn test_depthwise_separable_conv_zero_channels_err() {
        assert!(DepthwiseSeparableConv::new(0, 8, (3, 3)).is_err());
    }
}
838}