1use crate::error::{NeuralError, Result};
16use serde::{Deserialize, Serialize};
17
/// Hyper-parameters for a MobileNet-family model.
///
/// Ready-made presets are provided by [`MobileNetConfig::mobilenet_v1`],
/// [`MobileNetConfig::mobilenet_v2`] and [`MobileNetConfig::mobile_lite`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MobileNetConfig {
    /// Multiplier applied to base channel counts (see [`MobileNetConfig::scaled_channels`]).
    pub width_multiplier: f64,
    /// Input spatial resolution in pixels (a single side length; the presets use 224 and 128).
    pub input_resolution: usize,
    /// Number of output classes.
    pub num_classes: usize,
    /// Which MobileNet architecture variant this configuration describes.
    pub version: MobileNetVersion,
    /// Dropout rate (the presets use 0.0, 0.001 and 0.2).
    pub dropout_rate: f64,
    /// Whether batch normalisation is enabled.
    /// NOTE(review): not read by any code in this file.
    pub use_batch_norm: bool,
}
46
/// MobileNet architecture variant selected by a [`MobileNetConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MobileNetVersion {
    /// MobileNet V1.
    V1,
    /// MobileNet V2 (inverted-residual blocks; see `MobileNetV2Block`).
    V2,
    /// MobileNet V3, "small" variant.
    V3Small,
    /// MobileNet V3, "large" variant.
    V3Large,
}
59
60impl MobileNetConfig {
61 pub fn mobilenet_v1() -> Self {
63 Self {
64 width_multiplier: 1.0,
65 input_resolution: 224,
66 num_classes: 1000,
67 version: MobileNetVersion::V1,
68 dropout_rate: 0.001,
69 use_batch_norm: true,
70 }
71 }
72
73 pub fn mobilenet_v2() -> Self {
75 Self {
76 width_multiplier: 1.0,
77 input_resolution: 224,
78 num_classes: 1000,
79 version: MobileNetVersion::V2,
80 dropout_rate: 0.2,
81 use_batch_norm: true,
82 }
83 }
84
85 pub fn mobile_lite() -> Self {
87 Self {
88 width_multiplier: 0.25,
89 input_resolution: 128,
90 num_classes: 10,
91 version: MobileNetVersion::V2,
92 dropout_rate: 0.0,
93 use_batch_norm: true,
94 }
95 }
96
97 pub fn scaled_channels(&self, base: usize) -> usize {
99 ((base as f64) * self.width_multiplier).round() as usize
100 }
101
102 pub fn estimated_param_count(&self) -> usize {
104 let c = self.scaled_channels(32);
106 let dw_params = 3 * 3 * c; let pw_params = c * (c * 2); (dw_params + pw_params) * 4
109 }
110}
111
/// Depthwise-separable convolution.
///
/// A per-channel (depthwise) `kh × kw` convolution is followed by a 1×1
/// (pointwise) convolution across channels. Both stages apply ReLU6 in `forward`.
#[derive(Debug, Clone)]
pub struct DepthwiseSeparableConv {
    /// Number of input channels (also the number of depthwise filters).
    in_ch: usize,
    /// Number of output channels produced by the pointwise stage.
    out_ch: usize,
    /// Depthwise kernel size as `(kh, kw)`.
    kernel_size: (usize, usize),
    /// Depthwise weights, laid out as `[in_ch][kh][kw]`.
    dw_weights: Vec<f32>,
    /// Pointwise weights, laid out row-major as `[out_ch][in_ch]`.
    pw_weights: Vec<f32>,
    /// Pointwise bias, one entry per output channel.
    bias: Vec<f32>,
}
147
148impl DepthwiseSeparableConv {
149 pub fn new(
151 in_channels: usize,
152 out_channels: usize,
153 kernel_size: (usize, usize),
154 ) -> Result<Self> {
155 if in_channels == 0 || out_channels == 0 {
156 return Err(NeuralError::InvalidArgument(
157 "DepthwiseSeparableConv: channel counts must be > 0".to_string(),
158 ));
159 }
160 let (kh, kw) = kernel_size;
161 let dw_size = in_channels * kh * kw;
162 let pw_size = out_channels * in_channels;
163
164 let dw_scale = (2.0_f32 / (kh * kw) as f32).sqrt();
166 let pw_scale = (2.0_f32 / in_channels as f32).sqrt();
167
168 let dw_weights = pseudo_random_weights(dw_size, dw_scale, 1);
169 let pw_weights = pseudo_random_weights(pw_size, pw_scale, 2);
170 let bias = vec![0.0_f32; out_channels];
171
172 Ok(Self {
173 in_ch: in_channels,
174 out_ch: out_channels,
175 kernel_size,
176 dw_weights,
177 pw_weights,
178 bias,
179 })
180 }
181
182 pub fn in_channels(&self) -> usize {
184 self.in_ch
185 }
186
187 pub fn out_channels(&self) -> usize {
189 self.out_ch
190 }
191
192 pub fn kernel_size(&self) -> (usize, usize) {
194 self.kernel_size
195 }
196
197 pub fn parameter_count(&self) -> usize {
199 self.dw_weights.len() + self.pw_weights.len() + self.bias.len()
200 }
201
202 pub fn forward(
207 &self,
208 input: &[f32],
209 input_shape: [usize; 4],
210 ) -> Result<(Vec<f32>, [usize; 4])> {
211 let [batch, in_ch, h, w] = input_shape;
212 if in_ch != self.in_ch {
213 return Err(NeuralError::ShapeMismatch(format!(
214 "DepthwiseSeparableConv: expected in_ch={}, got {}",
215 self.in_ch, in_ch
216 )));
217 }
218 if input.len() != batch * in_ch * h * w {
219 return Err(NeuralError::ShapeMismatch(
220 "DepthwiseSeparableConv: input slice length mismatch".to_string(),
221 ));
222 }
223
224 let (kh, kw) = self.kernel_size;
225 let padding = (kh / 2, kw / 2);
226 let h_out = (h + 2 * padding.0).saturating_sub(kh) + 1;
227 let w_out = (w + 2 * padding.1).saturating_sub(kw) + 1;
228
229 let dw_size = batch * in_ch * h_out * w_out;
231 let mut dw_out = vec![0.0_f32; dw_size];
232
233 for b in 0..batch {
234 for c in 0..in_ch {
235 for oh in 0..h_out {
236 for ow in 0..w_out {
237 let mut acc = 0.0_f32;
238 for ki in 0..kh {
239 for kj in 0..kw {
240 let ih = oh + ki;
241 let iw = ow + kj;
242 let ih_src = ih.wrapping_sub(padding.0);
244 let iw_src = iw.wrapping_sub(padding.1);
245 if ih_src < h && iw_src < w {
246 let in_idx =
247 b * in_ch * h * w + c * h * w + ih_src * w + iw_src;
248 let w_idx = c * kh * kw + ki * kw + kj;
249 acc += input[in_idx] * self.dw_weights[w_idx];
250 }
251 }
252 }
253 let idx = b * in_ch * h_out * w_out + c * h_out * w_out + oh * w_out + ow;
255 dw_out[idx] = acc.clamp(0.0, 6.0);
256 }
257 }
258 }
259 }
260
261 let pw_size = batch * self.out_ch * h_out * w_out;
263 let mut pw_out = vec![0.0_f32; pw_size];
264
265 for b in 0..batch {
266 for oc in 0..self.out_ch {
267 for oh in 0..h_out {
268 for ow in 0..w_out {
269 let mut acc = self.bias[oc];
270 for ic in 0..in_ch {
271 let dw_idx =
272 b * in_ch * h_out * w_out + ic * h_out * w_out + oh * w_out + ow;
273 let pw_idx = oc * in_ch + ic;
274 acc += dw_out[dw_idx] * self.pw_weights[pw_idx];
275 }
276 let out_idx =
277 b * self.out_ch * h_out * w_out + oc * h_out * w_out + oh * w_out + ow;
278 pw_out[out_idx] = acc.clamp(0.0, 6.0);
279 }
280 }
281 }
282 }
283
284 Ok((pw_out, [batch, self.out_ch, h_out, w_out]))
285 }
286}
287
/// MobileNetV2 inverted-residual block.
///
/// The stages are an optional 1×1 expansion with ReLU6, a 3×3 depthwise
/// convolution with ReLU6, then a linear 1×1 projection, with an optional
/// identity shortcut.
#[derive(Debug, Clone)]
pub struct MobileNetV2Block {
    /// Number of input channels.
    in_ch: usize,
    /// Number of output channels.
    out_ch: usize,
    /// Channel expansion factor applied before the depthwise stage.
    expansion: usize,
    /// Spatial stride of the depthwise convolution.
    stride: usize,
    /// 1×1 expansion convolution (ReLU6); `None` when the expansion factor is 1.
    expand_pw: Option<PointwiseConv>,
    /// Depthwise stage.
    /// NOTE(review): `forward` reads only its depthwise weights and kernel size;
    /// its pointwise weights and bias are allocated but unused.
    dw: DepthwiseSeparableConv,
    /// Linear (no activation) 1×1 projection to `out_ch` channels.
    project_pw: PointwiseConv,
    /// Whether `forward` adds the block input to the projected output.
    use_residual: bool,
}
323
324impl MobileNetV2Block {
325 pub fn new(
333 in_channels: usize,
334 out_channels: usize,
335 expansion_factor: usize,
336 stride: usize,
337 ) -> Result<Self> {
338 if in_channels == 0 || out_channels == 0 {
339 return Err(NeuralError::InvalidArgument(
340 "MobileNetV2Block: channel counts must be > 0".to_string(),
341 ));
342 }
343 if stride == 0 {
344 return Err(NeuralError::InvalidArgument(
345 "MobileNetV2Block: stride must be >= 1".to_string(),
346 ));
347 }
348
349 let expanded_ch = in_channels * expansion_factor;
350
351 let expand_pw = if expansion_factor != 1 {
353 Some(PointwiseConv::new(in_channels, expanded_ch)?)
354 } else {
355 None
356 };
357
358 let dw = DepthwiseSeparableConv::new(expanded_ch, expanded_ch, (3, 3))?;
360 let project_pw = PointwiseConv::new(expanded_ch, out_channels)?;
362
363 let use_residual = stride == 1 && in_channels == out_channels;
364
365 Ok(Self {
366 in_ch: in_channels,
367 out_ch: out_channels,
368 expansion: expansion_factor,
369 stride,
370 expand_pw,
371 dw,
372 project_pw,
373 use_residual,
374 })
375 }
376
377 pub fn in_channels(&self) -> usize {
379 self.in_ch
380 }
381
382 pub fn out_channels(&self) -> usize {
384 self.out_ch
385 }
386
387 pub fn expansion(&self) -> usize {
389 self.expansion
390 }
391
392 pub fn stride(&self) -> usize {
394 self.stride
395 }
396
397 pub fn has_residual(&self) -> bool {
399 self.use_residual
400 }
401
402 pub fn parameter_count(&self) -> usize {
404 let expand = self
405 .expand_pw
406 .as_ref()
407 .map(|p| p.parameter_count())
408 .unwrap_or(0);
409 expand + self.dw.parameter_count() + self.project_pw.parameter_count()
410 }
411
412 pub fn forward(&self, input: &[f32], shape: [usize; 4]) -> Result<(Vec<f32>, [usize; 4])> {
417 let [batch, in_ch, h, w] = shape;
418 if in_ch != self.in_ch {
419 return Err(NeuralError::ShapeMismatch(format!(
420 "MobileNetV2Block: expected in_ch={}, got {}",
421 self.in_ch, in_ch
422 )));
423 }
424
425 let (expanded, expanded_shape) = if let Some(ref pw) = self.expand_pw {
427 pw.forward_with_relu6(input, shape)?
428 } else {
429 (input.to_vec(), shape)
430 };
431
432 let (dw_out, dw_shape) = depthwise_only(
436 &expanded,
437 expanded_shape,
438 &self.dw.dw_weights,
439 self.dw.kernel_size,
440 self.stride,
441 )?;
442
443 let (projected, proj_shape) = self
445 .project_pw
446 .forward_linear(dw_out.as_slice(), dw_shape)?;
447
448 let output = if self.use_residual {
450 input
451 .iter()
452 .zip(projected.iter())
453 .map(|(a, b)| a + b)
454 .collect()
455 } else {
456 projected
457 };
458
459 Ok((output, proj_shape))
460 }
461}
462
/// 1×1 convolution: a per-pixel linear map across channels, used inside
/// [`MobileNetV2Block`].
#[derive(Debug, Clone)]
struct PointwiseConv {
    /// Number of input channels.
    in_ch: usize,
    /// Number of output channels.
    out_ch: usize,
    /// Weight matrix stored row-major as `[out_ch][in_ch]`.
    weights: Vec<f32>,
    /// One bias per output channel.
    bias: Vec<f32>,
}
475
476impl PointwiseConv {
477 fn new(in_channels: usize, out_channels: usize) -> Result<Self> {
478 let size = out_channels * in_channels;
479 let scale = (2.0_f32 / in_channels as f32).sqrt();
480 Ok(Self {
481 in_ch: in_channels,
482 out_ch: out_channels,
483 weights: pseudo_random_weights(size, scale, 3),
484 bias: vec![0.0_f32; out_channels],
485 })
486 }
487
488 fn parameter_count(&self) -> usize {
489 self.weights.len() + self.bias.len()
490 }
491
492 fn forward_with_relu6(
494 &self,
495 input: &[f32],
496 shape: [usize; 4],
497 ) -> Result<(Vec<f32>, [usize; 4])> {
498 let [batch, in_ch, h, w] = shape;
499 if in_ch != self.in_ch {
500 return Err(NeuralError::ShapeMismatch(format!(
501 "PointwiseConv: in_ch mismatch {} vs {}",
502 self.in_ch, in_ch
503 )));
504 }
505 let out_size = batch * self.out_ch * h * w;
506 let mut out = vec![0.0_f32; out_size];
507
508 for b in 0..batch {
509 for oc in 0..self.out_ch {
510 for ph in 0..h {
511 for pw_pos in 0..w {
512 let mut acc = self.bias[oc];
513 for ic in 0..in_ch {
514 let in_idx = b * in_ch * h * w + ic * h * w + ph * w + pw_pos;
515 acc += input[in_idx] * self.weights[oc * in_ch + ic];
516 }
517 let out_idx = b * self.out_ch * h * w + oc * h * w + ph * w + pw_pos;
518 out[out_idx] = acc.clamp(0.0, 6.0);
519 }
520 }
521 }
522 }
523 Ok((out, [batch, self.out_ch, h, w]))
524 }
525
526 fn forward_linear(&self, input: &[f32], shape: [usize; 4]) -> Result<(Vec<f32>, [usize; 4])> {
528 let [batch, in_ch, h, w] = shape;
529 if in_ch != self.in_ch {
530 return Err(NeuralError::ShapeMismatch(format!(
531 "PointwiseConv(linear): in_ch mismatch {} vs {}",
532 self.in_ch, in_ch
533 )));
534 }
535 let out_size = batch * self.out_ch * h * w;
536 let mut out = vec![0.0_f32; out_size];
537 for b in 0..batch {
538 for oc in 0..self.out_ch {
539 for ph in 0..h {
540 for pw_pos in 0..w {
541 let mut acc = self.bias[oc];
542 for ic in 0..in_ch {
543 let in_idx = b * in_ch * h * w + ic * h * w + ph * w + pw_pos;
544 acc += input[in_idx] * self.weights[oc * in_ch + ic];
545 }
546 let out_idx = b * self.out_ch * h * w + oc * h * w + ph * w + pw_pos;
547 out[out_idx] = acc;
548 }
549 }
550 }
551 }
552 Ok((out, [batch, self.out_ch, h, w]))
553 }
554}
555
556fn depthwise_only(
562 input: &[f32],
563 shape: [usize; 4],
564 weights: &[f32],
565 kernel_size: (usize, usize),
566 stride: usize,
567) -> Result<(Vec<f32>, [usize; 4])> {
568 let [batch, channels, h, w] = shape;
569 let (kh, kw) = kernel_size;
570 let padding = (kh / 2, kw / 2);
571 let h_out = if stride == 1 {
572 h
573 } else {
574 (h + 2 * padding.0).saturating_sub(kh) / stride + 1
575 };
576 let w_out = if stride == 1 {
577 w
578 } else {
579 (w + 2 * padding.1).saturating_sub(kw) / stride + 1
580 };
581
582 let mut out = vec![0.0_f32; batch * channels * h_out * w_out];
583 for b in 0..batch {
584 for c in 0..channels {
585 for oh in 0..h_out {
586 for ow in 0..w_out {
587 let mut acc = 0.0_f32;
588 for ki in 0..kh {
589 for kj in 0..kw {
590 let ih = oh * stride + ki;
591 let iw = ow * stride + kj;
592 let ih_src = ih.wrapping_sub(padding.0);
593 let iw_src = iw.wrapping_sub(padding.1);
594 if ih_src < h && iw_src < w {
595 let in_idx = b * channels * h * w + c * h * w + ih_src * w + iw_src;
596 let w_idx = c * kh * kw + ki * kw + kj;
597 acc += input[in_idx] * weights[w_idx];
598 }
599 }
600 }
601 let out_idx =
602 b * channels * h_out * w_out + c * h_out * w_out + oh * w_out + ow;
603 out[out_idx] = acc.clamp(0.0, 6.0);
604 }
605 }
606 }
607 }
608 Ok((out, [batch, channels, h_out, w_out]))
609}
610
/// Helpers for fitting a model into a mobile deployment budget.
///
/// Provides size estimation, symmetric int8 quantisation and magnitude pruning.
#[derive(Debug, Clone, PartialEq)]
pub struct MobileOptimizer {
    /// Maximum model size in KiB accepted by `fits_budget` (must be > 0).
    pub size_budget_kb: f64,
    /// Maximum tolerated accuracy drop; `MobileOptimizer::new` clamps it to `[0, 1]`.
    pub max_accuracy_drop: f64,
}
624
625impl MobileOptimizer {
626 pub fn new(size_budget_kb: f64, max_accuracy_drop: f64) -> Result<Self> {
628 if size_budget_kb <= 0.0 {
629 return Err(NeuralError::InvalidArgument(
630 "size_budget_kb must be > 0".to_string(),
631 ));
632 }
633 Ok(Self {
634 size_budget_kb,
635 max_accuracy_drop: max_accuracy_drop.clamp(0.0, 1.0),
636 })
637 }
638
639 pub fn estimate_size_bytes(num_weights: usize, bits_per_weight: u8) -> usize {
641 (num_weights * bits_per_weight as usize).div_ceil(8)
642 }
643
644 pub fn quantize_int8(weights: &[f32]) -> Result<(Vec<i8>, f32)> {
648 if weights.is_empty() {
649 return Err(NeuralError::InvalidArgument(
650 "quantize_int8: empty weights".to_string(),
651 ));
652 }
653 let abs_max = weights.iter().fold(0.0_f32, |acc, &v| acc.max(v.abs()));
654 let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
655 let quantized: Vec<i8> = weights
656 .iter()
657 .map(|&w| (w / scale).round().clamp(-128.0, 127.0) as i8)
658 .collect();
659 Ok((quantized, scale))
660 }
661
662 pub fn magnitude_prune(weights: &mut [f32], sparsity: f64) {
664 if weights.is_empty() || sparsity <= 0.0 {
665 return;
666 }
667 let n = weights.len();
668 let mut sorted_abs: Vec<f32> = weights.iter().map(|v| v.abs()).collect();
669 sorted_abs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
670 let cutoff_idx = ((sparsity.clamp(0.0, 1.0) * n as f64) as usize).min(n.saturating_sub(1));
671 let threshold = sorted_abs[cutoff_idx];
672 for w in weights.iter_mut() {
673 if w.abs() < threshold {
674 *w = 0.0;
675 }
676 }
677 }
678
679 pub fn fits_budget(&self, param_count: usize) -> bool {
681 let bytes = Self::estimate_size_bytes(param_count, 32);
682 (bytes as f64 / 1024.0) <= self.size_budget_kb
683 }
684}
685
/// Deterministic, dependency-free weight initialiser.
///
/// Returns `n` values uniformly spread over `[-scale, scale)`. They are
/// generated by a 64-bit linear congruential generator (Knuth's MMIX
/// multiplier/increment) seeded with a fixed base plus `seed_offset`, so the
/// same arguments always yield the same weights.
fn pseudo_random_weights(n: usize, scale: f32, seed_offset: u64) -> Vec<f32> {
    // 2^24: the top 24 state bits map exactly onto f32's 24-bit significand.
    const UNIT: f32 = 16_777_216.0;
    let mut state: u64 = 0xDEAD_BEEF_0000_0001u64.wrapping_add(seed_offset);
    (0..n)
        .map(|_| {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // High bits are the best-mixed in an LCG. Taking exactly 24 of them
            // gives u in [0, 1). Dividing a 31-bit value by u32::MAX would only
            // reach [0, 0.5) and make every weight negative.
            let u = (state >> 40) as f32 / UNIT;
            (u * 2.0 - 1.0) * scale
        })
        .collect()
}
703
#[cfg(test)]
mod tests {
    use super::*;

    // ---- configuration presets -----------------------------------------

    #[test]
    fn test_mobile_net_config_v1() {
        let config = MobileNetConfig::mobilenet_v1();
        assert_eq!(config.version, MobileNetVersion::V1);
        assert_eq!(config.input_resolution, 224);
        assert!((config.width_multiplier - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_mobile_net_config_v2() {
        assert_eq!(MobileNetConfig::mobilenet_v2().version, MobileNetVersion::V2);
    }

    #[test]
    fn test_scaled_channels() {
        let half_width = MobileNetConfig {
            width_multiplier: 0.5,
            ..MobileNetConfig::mobilenet_v2()
        };
        for (base, expected) in [(32, 16), (64, 32)] {
            assert_eq!(half_width.scaled_channels(base), expected);
        }
    }

    // ---- depthwise-separable convolution -------------------------------

    #[test]
    fn test_depthwise_separable_conv_creation() {
        let layer = DepthwiseSeparableConv::new(4, 8, (3, 3)).expect("dsc ok");
        assert_eq!((layer.in_channels(), layer.out_channels()), (4, 8));
        assert!(layer.parameter_count() > 0);
    }

    #[test]
    fn test_depthwise_separable_conv_forward() {
        let layer = DepthwiseSeparableConv::new(2, 4, (3, 3)).expect("dsc ok");
        let input = vec![0.5_f32; 2 * 8 * 8];
        let (output, out_shape) = layer.forward(&input, [1, 2, 8, 8]).expect("forward ok");
        // A 3x3 kernel with padding 1 keeps the 8x8 spatial size.
        assert_eq!(out_shape, [1, 4, 8, 8]);
        assert_eq!(output.len(), out_shape.iter().product::<usize>());
    }

    #[test]
    fn test_depthwise_separable_conv_channel_mismatch_err() {
        let layer = DepthwiseSeparableConv::new(4, 8, (3, 3)).expect("dsc ok");
        // The layer expects 4 input channels; feed it 2.
        let input = vec![0.0_f32; 2 * 4 * 4];
        assert!(layer.forward(&input, [1, 2, 4, 4]).is_err());
    }

    #[test]
    fn test_depthwise_separable_conv_zero_channels_err() {
        assert!(DepthwiseSeparableConv::new(0, 8, (3, 3)).is_err());
    }

    // ---- MobileNetV2 inverted-residual block -----------------------------

    #[test]
    fn test_mobilenet_v2_block_creation() {
        let block = MobileNetV2Block::new(32, 16, 6, 1).expect("block ok");
        assert_eq!((block.in_channels(), block.out_channels()), (32, 16));
        // The channel count changes, so no identity shortcut is possible.
        assert!(!block.has_residual());
    }

    #[test]
    fn test_mobilenet_v2_block_residual() {
        let block = MobileNetV2Block::new(16, 16, 6, 1).expect("block ok");
        assert!(block.has_residual());
    }

    #[test]
    fn test_mobilenet_v2_block_forward() {
        let block = MobileNetV2Block::new(8, 8, 6, 1).expect("block ok");
        let input = vec![0.1_f32; 8 * 4 * 4];
        let (output, [batch, channels, _, _]) =
            block.forward(&input, [1, 8, 4, 4]).expect("fwd ok");
        assert_eq!(batch, 1);
        assert_eq!(channels, 8);
        assert_eq!(output.len(), 8 * 4 * 4);
    }

    #[test]
    fn test_mobilenet_v2_block_stride2() {
        let block = MobileNetV2Block::new(8, 16, 6, 2).expect("block ok");
        assert!(!block.has_residual());
        let input = vec![0.1_f32; 8 * 8 * 8];
        let (output, out_shape) = block.forward(&input, [1, 8, 8, 8]).expect("fwd ok");
        let [b, c, h, w] = out_shape;
        assert_eq!((b, c), (1, 16));
        // Stride 2 should at least halve the 8x8 spatial extent.
        assert!(h <= 4 && w <= 4, "expected ≤4, got h={h} w={w}");
        assert_eq!(output.len(), b * c * h * w);
    }

    // ---- optimizer utilities ---------------------------------------------

    #[test]
    fn test_mobile_optimizer_quantize_int8() {
        let weights = [0.5_f32, -0.5, 1.0, -1.0, 0.0];
        let (quantized, scale) = MobileOptimizer::quantize_int8(&weights).expect("ok");
        assert_eq!(quantized.len(), weights.len());
        for (&orig, &q) in weights.iter().zip(&quantized) {
            let deq = f32::from(q) * scale;
            assert!((orig - deq).abs() < 0.01, "orig={orig} deq={deq}");
        }
    }

    #[test]
    fn test_mobile_optimizer_prune() {
        let mut weights = vec![0.01_f32, 0.5, 0.001, 1.0, 0.002];
        MobileOptimizer::magnitude_prune(&mut weights, 0.6);
        let zeros = weights.iter().filter(|v| **v == 0.0).count();
        assert!(zeros >= 2, "expected ≥2 zeros, got {zeros}");
    }

    #[test]
    fn test_mobile_optimizer_budget() {
        let optimizer = MobileOptimizer::new(1000.0, 0.01).expect("ok");
        assert!(optimizer.fits_budget(10));
        assert!(!optimizer.fits_budget(10_000_000));
    }
}