1use crate::scirs2_compat::random::legacy;
7use crate::{MobileBackend, MobileConfig, PerformanceTier};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use trustformers_core::errors::Result;
11
12fn random_usize(max: usize) -> usize {
14 if max == 0 {
15 return 0;
16 }
17 ((legacy::f64() * max as f64) as usize).min(max.saturating_sub(1))
18}
19
/// Draw a uniform random `f32` from the legacy RNG (expected range `[0, 1)`
/// — TODO confirm against `scirs2_compat::random::legacy` docs).
fn random_f32() -> f32 {
    legacy::f32()
}
23
/// Neural architecture search (NAS) engine for mobile deployment.
#[derive(Debug, Clone)]
pub struct MobileNAS {
    /// Search-space, strategy, and constraint configuration.
    search_config: NASConfig,
    /// Candidate architectures accumulated during search.
    architecture_candidates: Vec<MobileArchitecture>,
    /// Records of every new-best architecture found so far.
    performance_history: Vec<PerformanceRecord>,
    /// RL agent, consulted only under the `ReinforcementLearning` strategy.
    optimization_agent: ReinforcementLearningAgent,
}
32
/// Top-level configuration for an architecture search run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NASConfig {
    /// Maximum number of search iterations.
    pub max_iterations: usize,
    /// Objectives folded into the fitness score.
    pub optimization_targets: Vec<OptimizationTarget>,
    /// Hard limits of the target device.
    pub device_constraints: DeviceConstraints,
    /// Strategy used to propose candidate architectures.
    pub search_strategy: SearchStrategy,
    /// Early-termination policy for stalled searches.
    pub early_stopping: EarlyStoppingConfig,
}
47
/// Objective dimensions the search can optimize for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OptimizationTarget {
    /// Inference latency (ms).
    Latency,
    /// Runtime memory footprint (MB).
    Memory,
    /// Average power draw (mW).
    Power,
    /// Model quality / task accuracy.
    Accuracy,
    /// On-disk model size (MB).
    ModelSize,
    /// Energy per inference (mJ).
    Energy,
}
64
/// Hard resource limits of the target mobile device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceConstraints {
    /// Maximum memory budget in megabytes.
    pub max_memory_mb: usize,
    /// Maximum acceptable inference latency in milliseconds.
    pub max_latency_ms: f32,
    /// Device performance class.
    pub performance_tier: PerformanceTier,
    /// Compute backends present on the device.
    pub available_backends: Vec<MobileBackend>,
    /// Power budget in milliwatts.
    pub power_budget_mw: f32,
}
79
/// Candidate-generation strategies for the search loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchStrategy {
    /// Independent random mutations each iteration.
    Random,
    /// Evolutionary search via mutation/crossover of the incumbent.
    Evolutionary {
        population_size: usize,
        mutation_rate: f32,
        crossover_rate: f32,
    },
    /// Q-learning-style agent proposes architecture edits.
    ReinforcementLearning {
        learning_rate: f32,
        exploration_rate: f32,
        replay_buffer_size: usize,
    },
    /// DARTS-like continuous relaxation of layer parameters.
    Differentiable {
        temperature: f32,
        gumbel_softmax: bool,
    },
    /// Staged search mutating one architecture aspect per stage.
    Progressive {
        stages: usize,
        pruning_threshold: f32,
    },
}
108
/// Early-stopping policy for the search loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
    /// Iterations without improvement before stopping.
    pub patience: usize,
    /// Minimum score delta counted as an improvement.
    pub min_improvement: f32,
    /// Metric watched for improvement.
    pub monitor_metric: OptimizationTarget,
}
119
/// A concrete candidate network architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MobileArchitecture {
    /// Unique identifier for this candidate.
    pub id: String,
    /// Ordered layer stack.
    pub layers: Vec<LayerConfig>,
    /// Non-sequential connections between layers.
    pub skip_connections: Vec<SkipConnection>,
    /// Per-layer quantization assignment.
    pub quantization: QuantizationConfig,
    /// Cost-model estimates, populated after evaluation.
    pub estimated_metrics: Option<ArchitectureMetrics>,
}
134
/// Configuration of a single layer in a candidate architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerConfig {
    /// Operator type and its structural hyperparameters.
    pub layer_type: LayerType,
    /// Input tensor dimensions.
    pub input_dim: Vec<usize>,
    /// Output tensor dimensions.
    pub output_dim: Vec<usize>,
    /// Named tunable parameters (e.g. "channels") mutated during search.
    pub parameters: HashMap<String, f32>,
    /// Activation applied after the layer.
    pub activation: ActivationType,
}
149
/// Mobile-friendly operator types available to the search space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LayerType {
    /// Depthwise + pointwise convolution pair.
    DepthwiseSeparableConv {
        kernel_size: usize,
        stride: usize,
        dilation: usize,
    },
    /// Inverted-residual (MobileNet-style) bottleneck block.
    MobileBottleneck {
        expansion_ratio: f32,
        kernel_size: usize,
        squeeze_excitation: bool,
    },
    /// Lightweight channel-attention block.
    EfficientChannelAttention {
        reduction_ratio: usize,
        use_gating: bool,
    },
    /// Multi-head attention sized for mobile inference.
    MobileMultiHeadAttention {
        num_heads: usize,
        head_dim: usize,
        sparse_attention: bool,
    },
    /// Group normalization layer.
    GroupNormalization { num_groups: usize },
    /// Dense layer, optionally quantized.
    MobileLinear { use_bias: bool, quantized: bool },
}
181
/// Activation functions available to the search space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationType {
    /// x * sigmoid(x).
    Swish,
    /// Piecewise-linear approximation of Swish.
    HardSwish,
    /// ReLU clipped at 6.
    ReLU6,
    /// Tanh-based GELU approximation.
    GeluApprox,
    /// x * tanh(softplus(x)).
    Mish,
}
196
/// A non-sequential edge between two layers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkipConnection {
    /// Index of the source layer.
    pub from_layer: usize,
    /// Index of the destination layer.
    pub to_layer: usize,
    /// How the two feature maps are combined.
    pub connection_type: ConnectionType,
}
207
/// Ways to merge a skip connection into its destination.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConnectionType {
    /// Element-wise addition.
    Residual,
    /// Channel-wise concatenation (DenseNet-style).
    Dense,
    /// Attention-weighted combination.
    Attention { num_heads: usize },
    /// ShuffleNet-style channel shuffle.
    ChannelShuffle,
}
220
/// Per-architecture quantization assignment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Scheme chosen for each layer, keyed by layer index.
    pub layer_schemes: HashMap<usize, QuantizationScheme>,
    /// Allow different precisions across layers.
    pub mixed_precision: bool,
    /// Quantize activations dynamically at runtime.
    pub dynamic_quantization: bool,
}
231
/// Numeric precision schemes assignable to a layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// 4-bit integers.
    Int4 { symmetric: bool },
    /// 8-bit integers.
    Int8 { symmetric: bool },
    /// Half precision.
    FP16,
    /// Block-wise quantization with shared scales.
    BlockWise { block_size: usize },
    /// Full precision (no quantization).
    FP32,
}
246
/// Estimated or measured performance figures for one architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureMetrics {
    /// Inference latency in milliseconds.
    pub latency_ms: f32,
    /// Runtime memory footprint in megabytes.
    pub memory_mb: f32,
    /// Average power draw in milliwatts.
    pub power_mw: f32,
    /// Task accuracy; `None` until measured on real data.
    pub accuracy: Option<f32>,
    /// Serialized model size in megabytes.
    pub model_size_mb: f32,
    /// Energy per inference in millijoules.
    pub energy_per_inference_mj: f32,
    /// Inferences per second.
    pub throughput_fps: f32,
}
265
/// One evaluated (architecture, metrics) entry in the search history.
#[derive(Debug, Clone)]
pub struct PerformanceRecord {
    /// The evaluated candidate.
    pub architecture: MobileArchitecture,
    /// Its estimated metrics at evaluation time.
    pub metrics: ArchitectureMetrics,
    /// Device configuration used during evaluation.
    pub device_config: MobileConfig,
    /// Wall-clock time the record was created.
    pub timestamp: std::time::SystemTime,
    /// Optional user context active during the search.
    pub user_context: Option<UserContext>,
}
280
/// User-specific signals used to bias the fitness score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserContext {
    /// Observed workload patterns.
    pub usage_patterns: Vec<UsagePattern>,
    /// Stated optimization preferences.
    pub preferences: UserPreferences,
    /// Physical/operational environment of the device.
    pub environment: DeviceEnvironment,
}
291
/// A recurring workload observed on the device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsagePattern {
    /// Task label (free-form string).
    pub task_type: String,
    /// Relative frequency of this task, presumably in [0, 1] — TODO confirm.
    pub frequency: f32,
    /// Typical input shapes/batches for the task.
    pub input_characteristics: InputCharacteristics,
    /// Latency/battery/accuracy requirements for the task.
    pub performance_requirements: PerformanceRequirements,
}
304
/// Shapes and types of inputs seen for a usage pattern.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputCharacteristics {
    /// Commonly observed input dimension lists.
    pub input_sizes: Vec<Vec<usize>>,
    /// Commonly observed batch sizes.
    pub common_batch_sizes: Vec<usize>,
    /// Input data type names.
    pub data_types: Vec<String>,
}
315
/// Per-task performance requirements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceRequirements {
    /// Hard latency ceiling for the task in milliseconds.
    pub max_latency_ms: f32,
    /// Relative importance of battery life.
    pub battery_importance: f32,
    /// Relative importance of accuracy.
    pub accuracy_importance: f32,
}
326
/// User-stated optimization priorities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserPreferences {
    /// Highest-priority objective.
    pub primary_target: OptimizationTarget,
    /// Lower-priority objectives.
    pub secondary_targets: Vec<OptimizationTarget>,
    /// Acceptable quality/performance trade-off bounds.
    pub quality_tradeoffs: QualityTradeoffs,
}
337
/// Bounds on acceptable degradation relative to a baseline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityTradeoffs {
    /// Maximum tolerated accuracy loss.
    pub max_accuracy_loss: f32,
    /// Maximum tolerated latency increase.
    pub max_latency_increase: f32,
    /// Maximum tolerated memory increase.
    pub max_memory_increase: f32,
}
348
/// Operational environment of the target device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceEnvironment {
    /// Typical charging behavior.
    pub charging_status: ChargingPattern,
    /// Typical network connectivity.
    pub network_patterns: NetworkPattern,
    /// Typical thermal conditions.
    pub thermal_environment: ThermalEnvironment,
}
359
/// Coarse classification of how often the device is charged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChargingPattern {
    /// Charged often; power savings matter less.
    FrequentCharging,
    /// Charged at a typical daily cadence.
    ModerateCharging,
    /// Rarely charged; prioritize energy efficiency.
    InfrequentCharging,
}
370
/// Coarse classification of the device's connectivity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NetworkPattern {
    /// Mostly on Wi-Fi.
    PrimarilyWiFi,
    /// Mix of Wi-Fi and cellular.
    Mixed,
    /// Mostly on cellular.
    PrimarilyCellular,
    /// Often offline; on-device inference essential.
    FrequentOffline,
}
383
/// Coarse classification of ambient thermal conditions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ThermalEnvironment {
    /// Cool surroundings; little throttling risk.
    Cool,
    /// Typical conditions.
    Moderate,
    /// Warm surroundings; throttling likely.
    Warm,
    /// Conditions vary widely.
    Variable,
}
396
/// Epsilon-greedy agent that proposes architecture edits.
#[derive(Debug, Clone)]
pub struct ReinforcementLearningAgent {
    /// Learning hyperparameters.
    config: RLConfig,
    /// Value network (currently zero-initialized and unused for selection).
    q_network: QNetwork,
    /// Stored experiences for replay training.
    replay_buffer: Vec<Experience>,
    /// Current epsilon; decays toward `config.min_exploration_rate`.
    exploration_rate: f32,
}
409
/// Hyperparameters for the RL agent.
#[derive(Debug, Clone)]
pub struct RLConfig {
    /// Q-network learning rate.
    pub learning_rate: f32,
    /// Discount factor (gamma).
    pub discount_factor: f32,
    /// Epsilon at the start of training.
    pub initial_exploration_rate: f32,
    /// Multiplicative epsilon decay per update.
    pub exploration_decay: f32,
    /// Epsilon floor.
    pub min_exploration_rate: f32,
}
424
/// Minimal Q-network representation.
#[derive(Debug, Clone)]
pub struct QNetwork {
    /// Weight matrix/matrices (layout defined by construction site).
    weights: Vec<Vec<f32>>,
    /// Layer widths of the network.
    architecture: Vec<usize>,
}
433
/// One (s, a, r, s') transition for replay training.
#[derive(Debug, Clone)]
pub struct Experience {
    /// Encoded architecture state before the action.
    pub state: Vec<f32>,
    /// Action taken.
    pub action: ArchitectureAction,
    /// Reward observed.
    pub reward: f32,
    /// Encoded state after the action.
    pub next_state: Vec<f32>,
    /// Whether the episode terminated.
    pub done: bool,
}
448
/// Discrete edits the RL agent can apply to an architecture.
#[derive(Debug, Clone)]
pub enum ArchitectureAction {
    /// Insert a new layer at `position`.
    AddLayer {
        layer_type: LayerType,
        position: usize,
    },
    /// Delete the layer at `position`.
    RemoveLayer { position: usize },
    /// Set a named parameter of the layer at `position`.
    ModifyLayer {
        position: usize,
        parameter: String,
        value: f32,
    },
    /// Reassign the quantization scheme of one layer.
    ChangeQuantization {
        layer: usize,
        scheme: QuantizationScheme,
    },
    /// Add a skip connection between two layers.
    AddSkipConnection {
        from: usize,
        to: usize,
        connection_type: ConnectionType,
    },
    /// Remove the skip connection between two layers.
    RemoveSkipConnection { from: usize, to: usize },
}
479
480impl MobileNAS {
481 pub fn new(config: NASConfig) -> Self {
483 let rl_config = RLConfig {
484 learning_rate: 0.001,
485 discount_factor: 0.99,
486 initial_exploration_rate: 1.0,
487 exploration_decay: 0.995,
488 min_exploration_rate: 0.1,
489 };
490
491 Self {
492 search_config: config,
493 architecture_candidates: Vec::new(),
494 performance_history: Vec::new(),
495 optimization_agent: ReinforcementLearningAgent::new(rl_config),
496 }
497 }
498
499 pub fn search_optimal_architecture(
501 &mut self,
502 base_architecture: MobileArchitecture,
503 user_context: Option<UserContext>,
504 ) -> Result<MobileArchitecture> {
505 let mut best_architecture = base_architecture.clone();
506 let mut best_score = f32::NEG_INFINITY;
507 let mut iterations_without_improvement = 0;
508
509 for iteration in 0..self.search_config.max_iterations {
510 let candidate = match &self.search_config.search_strategy {
512 SearchStrategy::Random => self.generate_random_architecture(&base_architecture)?,
513 SearchStrategy::Evolutionary { .. } => {
514 self.evolve_architecture(&best_architecture)?
515 },
516 SearchStrategy::ReinforcementLearning { .. } => {
517 self.rl_generate_architecture(&best_architecture)?
518 },
519 SearchStrategy::Differentiable { .. } => {
520 self.differentiable_search(&best_architecture)?
521 },
522 SearchStrategy::Progressive { .. } => {
523 self.progressive_search(&best_architecture, iteration)?
524 },
525 };
526
527 let metrics = self.evaluate_architecture(&candidate)?;
529 let score = self.calculate_fitness_score(&metrics, &user_context)?;
530
531 if score > best_score {
533 best_score = score;
534 best_architecture = candidate.clone();
535 iterations_without_improvement = 0;
536
537 let record = PerformanceRecord {
539 architecture: candidate,
540 metrics,
541 device_config: MobileConfig::default(), timestamp: std::time::SystemTime::now(),
543 user_context: user_context.clone(),
544 };
545 self.performance_history.push(record);
546 } else {
547 iterations_without_improvement += 1;
548 }
549
550 if iterations_without_improvement >= self.search_config.early_stopping.patience {
552 println!(
553 "Early stopping at iteration {} due to no improvement",
554 iteration
555 );
556 break;
557 }
558
559 if matches!(
561 self.search_config.search_strategy,
562 SearchStrategy::ReinforcementLearning { .. }
563 ) {
564 self.optimization_agent.update_from_experience(score)?;
565 }
566 }
567
568 Ok(best_architecture)
569 }
570
571 fn generate_random_architecture(
573 &self,
574 base: &MobileArchitecture,
575 ) -> Result<MobileArchitecture> {
576 let mut candidate = base.clone();
577
578 for _ in 0..3 {
580 match random_usize(4) {
581 0 => self.mutate_layer_params(&mut candidate)?,
582 1 => self.mutate_quantization(&mut candidate)?,
583 2 => self.mutate_skip_connections(&mut candidate)?,
584 _ => self.mutate_architecture_structure(&mut candidate)?,
585 }
586 }
587
588 Ok(candidate)
589 }
590
591 fn evolve_architecture(&self, parent: &MobileArchitecture) -> Result<MobileArchitecture> {
593 let mut offspring = parent.clone();
595
596 if random_f32() < 0.3 {
598 self.mutate_layer_params(&mut offspring)?;
599 }
600 if random_f32() < 0.2 {
601 self.mutate_quantization(&mut offspring)?;
602 }
603 if random_f32() < 0.1 {
604 self.mutate_skip_connections(&mut offspring)?;
605 }
606
607 Ok(offspring)
608 }
609
610 fn rl_generate_architecture(
612 &mut self,
613 current: &MobileArchitecture,
614 ) -> Result<MobileArchitecture> {
615 let state = self.encode_architecture_state(current)?;
616 let action = self.optimization_agent.select_action(&state)?;
617 let mut new_architecture = current.clone();
618
619 self.apply_architecture_action(&mut new_architecture, action)?;
620
621 Ok(new_architecture)
622 }
623
624 fn differentiable_search(&self, base: &MobileArchitecture) -> Result<MobileArchitecture> {
626 let mut candidate = base.clone();
628
629 for layer in &mut candidate.layers {
631 if let Some(param) = layer.parameters.get_mut("channels") {
633 *param *= 1.0 + (random_f32() - 0.5) * 0.1; }
635 }
636
637 Ok(candidate)
638 }
639
640 fn progressive_search(
642 &self,
643 base: &MobileArchitecture,
644 iteration: usize,
645 ) -> Result<MobileArchitecture> {
646 let mut candidate = base.clone();
647
648 let stage = iteration / (self.search_config.max_iterations / 4);
650 match stage {
651 0 => self.mutate_layer_params(&mut candidate)?,
652 1 => self.mutate_quantization(&mut candidate)?,
653 2 => self.mutate_skip_connections(&mut candidate)?,
654 _ => self.mutate_architecture_structure(&mut candidate)?,
655 }
656
657 Ok(candidate)
658 }
659
660 fn evaluate_architecture(
662 &self,
663 architecture: &MobileArchitecture,
664 ) -> Result<ArchitectureMetrics> {
665 let mut total_params = 0;
667 let mut total_flops = 0;
668 let mut memory_usage = 0;
669
670 for layer in &architecture.layers {
671 let (params, flops, memory) = self.estimate_layer_metrics(layer)?;
672 total_params += params;
673 total_flops += flops;
674 memory_usage += memory;
675 }
676
677 let latency_ms = self.estimate_latency(total_flops, &architecture.quantization)?;
679 let memory_mb = memory_usage as f32 / (1024.0 * 1024.0);
680 let power_mw = self.estimate_power_consumption(total_flops, latency_ms)?;
681 let model_size_mb = (total_params * 4) as f32 / (1024.0 * 1024.0); let energy_per_inference_mj = power_mw * latency_ms;
683 let throughput_fps = 1000.0 / latency_ms;
684
685 Ok(ArchitectureMetrics {
686 latency_ms,
687 memory_mb,
688 power_mw,
689 accuracy: None, model_size_mb,
691 energy_per_inference_mj,
692 throughput_fps,
693 })
694 }
695
696 fn calculate_fitness_score(
698 &self,
699 metrics: &ArchitectureMetrics,
700 user_context: &Option<UserContext>,
701 ) -> Result<f32> {
702 let mut score = 0.0;
703 let mut total_weight = 0.0;
704
705 for &target in &self.search_config.optimization_targets {
707 let (value, weight) = match target {
708 OptimizationTarget::Latency => {
709 let normalized = 1.0 / (1.0 + metrics.latency_ms / 100.0);
710 (normalized, 1.0)
711 },
712 OptimizationTarget::Memory => {
713 let normalized = 1.0 / (1.0 + metrics.memory_mb / 512.0);
714 (normalized, 1.0)
715 },
716 OptimizationTarget::Power => {
717 let normalized = 1.0 / (1.0 + metrics.power_mw / 1000.0);
718 (normalized, 1.0)
719 },
720 OptimizationTarget::ModelSize => {
721 let normalized = 1.0 / (1.0 + metrics.model_size_mb / 100.0);
722 (normalized, 1.0)
723 },
724 OptimizationTarget::Energy => {
725 let normalized = 1.0 / (1.0 + metrics.energy_per_inference_mj / 10.0);
726 (normalized, 1.0)
727 },
728 OptimizationTarget::Accuracy => {
729 let normalized = metrics.accuracy.unwrap_or(0.8);
730 (normalized, 2.0) },
732 };
733
734 score += value * weight;
735 total_weight += weight;
736 }
737
738 if let Some(ref context) = user_context {
740 score = self.adjust_score_for_user_context(score, metrics, context)?;
741 }
742
743 score = self.apply_constraint_penalties(score, metrics)?;
745
746 Ok(score / total_weight)
747 }
748
749 fn adjust_score_for_user_context(
751 &self,
752 base_score: f32,
753 metrics: &ArchitectureMetrics,
754 context: &UserContext,
755 ) -> Result<f32> {
756 let mut adjusted_score = base_score;
757
758 match context.preferences.primary_target {
760 OptimizationTarget::Latency if metrics.latency_ms > 50.0 => {
761 adjusted_score *= 0.8; },
763 OptimizationTarget::Memory if metrics.memory_mb > 256.0 => {
764 adjusted_score *= 0.8; },
766 OptimizationTarget::Power if metrics.power_mw > 500.0 => {
767 adjusted_score *= 0.8; },
769 _ => {},
770 }
771
772 for pattern in &context.usage_patterns {
774 if pattern.frequency > 0.5
775 && metrics.latency_ms > pattern.performance_requirements.max_latency_ms
776 {
777 adjusted_score *= 0.9; }
779 }
780
781 Ok(adjusted_score)
782 }
783
784 fn apply_constraint_penalties(
786 &self,
787 base_score: f32,
788 metrics: &ArchitectureMetrics,
789 ) -> Result<f32> {
790 let mut score = base_score;
791
792 if metrics.memory_mb > self.search_config.device_constraints.max_memory_mb as f32 {
794 score *= 0.5; }
796
797 if metrics.latency_ms > self.search_config.device_constraints.max_latency_ms {
799 score *= 0.5; }
801
802 if metrics.power_mw > self.search_config.device_constraints.power_budget_mw {
804 score *= 0.7; }
806
807 Ok(score)
808 }
809
810 fn mutate_layer_params(&self, architecture: &mut MobileArchitecture) -> Result<()> {
812 if !architecture.layers.is_empty() {
813 let layer_idx = random_usize(architecture.layers.len());
814 let layer = &mut architecture.layers[layer_idx];
815
816 if !layer.parameters.is_empty() {
818 let keys: Vec<_> = layer.parameters.keys().cloned().collect();
819 let param_key = &keys[random_usize(keys.len())];
820 if let Some(value) = layer.parameters.get_mut(param_key) {
821 *value *= 1.0 + (random_f32() - 0.5) * 0.2; }
823 }
824 }
825 Ok(())
826 }
827
828 fn mutate_quantization(&self, architecture: &mut MobileArchitecture) -> Result<()> {
829 if !architecture.layers.is_empty() {
830 let layer_idx = random_usize(architecture.layers.len());
831 let schemes = [
832 QuantizationScheme::Int4 { symmetric: true },
833 QuantizationScheme::Int8 { symmetric: true },
834 QuantizationScheme::FP16,
835 QuantizationScheme::FP32,
836 ];
837 let scheme = schemes[random_usize(schemes.len())].clone();
838 architecture.quantization.layer_schemes.insert(layer_idx, scheme);
839 }
840 Ok(())
841 }
842
843 fn mutate_skip_connections(&self, _architecture: &mut MobileArchitecture) -> Result<()> {
844 Ok(())
846 }
847
848 fn mutate_architecture_structure(&self, _architecture: &mut MobileArchitecture) -> Result<()> {
849 Ok(())
851 }
852
853 fn estimate_layer_metrics(&self, layer: &LayerConfig) -> Result<(usize, usize, usize)> {
854 let params =
856 layer.input_dim.iter().product::<usize>() * layer.output_dim.iter().product::<usize>();
857 let flops = params * 2; let memory = params * 4; Ok((params, flops, memory))
860 }
861
862 fn estimate_latency(
863 &self,
864 total_flops: usize,
865 _quantization: &QuantizationConfig,
866 ) -> Result<f32> {
867 let base_latency = total_flops as f32 / 1_000_000.0; Ok(base_latency)
870 }
871
872 fn estimate_power_consumption(&self, total_flops: usize, latency_ms: f32) -> Result<f32> {
873 let power = (total_flops as f32 / 1_000_000.0) * 100.0 + latency_ms * 10.0;
875 Ok(power)
876 }
877
878 fn encode_architecture_state(&self, _architecture: &MobileArchitecture) -> Result<Vec<f32>> {
879 Ok(vec![0.5; 128]) }
882
883 fn apply_architecture_action(
884 &self,
885 _architecture: &mut MobileArchitecture,
886 _action: ArchitectureAction,
887 ) -> Result<()> {
888 Ok(())
890 }
891}
892
893impl ReinforcementLearningAgent {
894 fn new(config: RLConfig) -> Self {
895 Self {
896 exploration_rate: config.initial_exploration_rate,
897 config,
898 q_network: QNetwork {
899 weights: vec![vec![0.0; 128]; 64], architecture: vec![128, 64, 32, 16],
901 },
902 replay_buffer: Vec::new(),
903 }
904 }
905
906 fn select_action(&mut self, _state: &[f32]) -> Result<ArchitectureAction> {
907 let actions = vec![
909 ArchitectureAction::ModifyLayer {
910 position: 0,
911 parameter: "channels".to_string(),
912 value: 64.0,
913 },
914 ];
916
917 let action_idx = if random_f32() < self.exploration_rate {
918 random_usize(actions.len())
920 } else {
921 0 };
924
925 Ok(actions[action_idx].clone())
926 }
927
928 fn update_from_experience(&mut self, reward: f32) -> Result<()> {
929 self.exploration_rate = (self.exploration_rate * self.config.exploration_decay)
931 .max(self.config.min_exploration_rate);
932
933 Ok(())
937 }
938}
939
940impl Default for NASConfig {
941 fn default() -> Self {
942 Self {
943 max_iterations: 100,
944 optimization_targets: vec![
945 OptimizationTarget::Latency,
946 OptimizationTarget::Memory,
947 OptimizationTarget::Power,
948 ],
949 device_constraints: DeviceConstraints {
950 max_memory_mb: 512,
951 max_latency_ms: 100.0,
952 performance_tier: PerformanceTier::Mid,
953 available_backends: vec![MobileBackend::CPU, MobileBackend::GPU],
954 power_budget_mw: 1000.0,
955 },
956 search_strategy: SearchStrategy::Evolutionary {
957 population_size: 20,
958 mutation_rate: 0.1,
959 crossover_rate: 0.7,
960 },
961 early_stopping: EarlyStoppingConfig {
962 patience: 10,
963 min_improvement: 0.01,
964 monitor_metric: OptimizationTarget::Latency,
965 },
966 }
967 }
968}
969
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly constructed engine starts with no candidates.
    #[test]
    fn test_mobile_nas_creation() {
        let config = NASConfig::default();
        let nas = MobileNAS::new(config);
        assert_eq!(nas.architecture_candidates.len(), 0);
    }

    // ArchitectureMetrics is a plain data holder; fields round-trip as set.
    #[test]
    fn test_architecture_metrics() {
        let metrics = ArchitectureMetrics {
            latency_ms: 50.0,
            memory_mb: 128.0,
            power_mw: 500.0,
            accuracy: Some(0.9),
            model_size_mb: 25.0,
            energy_per_inference_mj: 25.0,
            throughput_fps: 20.0,
        };

        assert_eq!(metrics.latency_ms, 50.0);
        assert_eq!(metrics.throughput_fps, 20.0);
    }

    // Default config uses 100 iterations and includes latency as a target.
    #[test]
    fn test_nas_config_default() {
        let config = NASConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert!(config.optimization_targets.contains(&OptimizationTarget::Latency));
    }
1003}