1use anyhow::Result;
11use scirs2_core::random::*; use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::time::{Duration, Instant};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct NASConfig {
19 pub algorithm: NASAlgorithm,
21 pub search_space: SearchSpaceConfig,
23 pub hardware_constraints: HardwareConstraints,
25 pub objectives: Vec<Objective>,
27 pub max_search_time: Duration,
29 pub max_architectures: usize,
31 pub early_stopping: EarlyStoppingConfig,
33 pub progressive_search: bool,
35 pub hardware_aware: bool,
37 pub multi_objective: bool,
39}
40
41impl Default for NASConfig {
42 fn default() -> Self {
43 Self {
44 algorithm: NASAlgorithm::DARTS,
45 search_space: SearchSpaceConfig::default(),
46 hardware_constraints: HardwareConstraints::default(),
47 objectives: vec![Objective::Accuracy, Objective::Efficiency],
48 max_search_time: Duration::from_secs(3600 * 24), max_architectures: 1000,
50 early_stopping: EarlyStoppingConfig::default(),
51 progressive_search: true,
52 hardware_aware: true,
53 multi_objective: true,
54 }
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub enum NASAlgorithm {
61 DARTS, GDAS, ENAS, ProxylessNAS, Progressive, Evolutionary, Random, }
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct SearchSpaceConfig {
73 pub operations: Vec<Operation>,
75 pub depth_range: (usize, usize),
77 pub width_range: (f32, f32),
79 pub activations: Vec<Activation>,
81 pub attention_types: Vec<AttentionType>,
83 pub normalizations: Vec<Normalization>,
85}
86
87impl Default for SearchSpaceConfig {
88 fn default() -> Self {
89 Self {
90 operations: vec![
91 Operation::Conv1x1,
92 Operation::Conv3x3,
93 Operation::SeparableConv3x3,
94 Operation::DilatedConv3x3,
95 Operation::MobileConv,
96 Operation::Identity,
97 Operation::MaxPool,
98 Operation::AvgPool,
99 ],
100 depth_range: (12, 48),
101 width_range: (0.5, 2.0),
102 activations: vec![
103 Activation::ReLU,
104 Activation::GELU,
105 Activation::Swish,
106 Activation::Mish,
107 ],
108 attention_types: vec![
109 AttentionType::MultiHead,
110 AttentionType::GroupedQuery,
111 AttentionType::FlashAttention,
112 AttentionType::LinearAttention,
113 ],
114 normalizations: vec![
115 Normalization::LayerNorm,
116 Normalization::RMSNorm,
117 Normalization::BatchNorm,
118 ],
119 }
120 }
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct HardwareConstraints {
126 pub max_parameters: usize,
128 pub max_memory: usize,
130 pub max_latency: f32,
132 pub max_flops: usize,
134 pub target_platform: TargetPlatform,
136 pub max_power: f32,
138}
139
140impl Default for HardwareConstraints {
141 fn default() -> Self {
142 Self {
143 max_parameters: 1_000_000_000, max_memory: 8_000_000_000, max_latency: 100.0, max_flops: 1_000_000_000_000, target_platform: TargetPlatform::GPU,
148 max_power: 250.0, }
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub enum Operation {
156 Conv1x1,
157 Conv3x3,
158 SeparableConv3x3,
159 DilatedConv3x3,
160 MobileConv,
161 Identity,
162 MaxPool,
163 AvgPool,
164 GlobalAvgPool,
165 Linear,
166 Embedding,
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
171pub enum Activation {
172 ReLU,
173 GELU,
174 Swish,
175 Mish,
176 Tanh,
177 Sigmoid,
178 LeakyReLU,
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
183pub enum AttentionType {
184 MultiHead,
185 GroupedQuery,
186 FlashAttention,
187 LinearAttention,
188 SparseAttention,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
193pub enum Normalization {
194 LayerNorm,
195 RMSNorm,
196 BatchNorm,
197 GroupNorm,
198}
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
202pub enum TargetPlatform {
203 CPU,
204 GPU,
205 TPU,
206 Mobile,
207 Edge,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
212pub enum Objective {
213 Accuracy,
214 Efficiency,
215 Latency,
216 Memory,
217 Power,
218 FLOPS,
219}
220
221#[derive(Debug, Clone, Serialize, Deserialize)]
223pub struct EarlyStoppingConfig {
224 pub patience: usize,
226 pub min_improvement: f32,
228 pub enabled: bool,
230}
231
232impl Default for EarlyStoppingConfig {
233 fn default() -> Self {
234 Self {
235 patience: 10,
236 min_improvement: 0.01,
237 enabled: true,
238 }
239 }
240}
241
242#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct Architecture {
245 pub id: String,
247 pub encoding: Vec<LayerSpec>,
249 pub metrics: PerformanceMetrics,
251 pub hardware_metrics: HardwareMetrics,
253 pub training_history: Vec<TrainingMetric>,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct LayerSpec {
260 pub layer_type: LayerType,
262 pub parameters: HashMap<String, f32>,
264 pub dimensions: (usize, usize),
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
270pub enum LayerType {
271 Transformer,
272 Convolution,
273 Attention,
274 MLP,
275 Normalization,
276 Activation,
277 Pooling,
278 Embedding,
279}
280
281#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct PerformanceMetrics {
284 pub accuracy: f32,
286 pub loss: f32,
288 pub inference_time: Duration,
290 pub memory_usage: usize,
292 pub parameter_count: usize,
294 pub flops: usize,
296}
297
298#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct HardwareMetrics {
301 pub gpu_utilization: f32,
303 pub memory_bandwidth: f32,
305 pub power_consumption: f32,
307 pub temperature: f32,
309}
310
311#[derive(Debug, Clone, Serialize, Deserialize)]
313pub struct TrainingMetric {
314 pub step: usize,
316 pub loss: f32,
318 pub accuracy: f32,
320 pub learning_rate: f32,
322}
323
324#[allow(dead_code)]
326pub struct NASController {
327 config: NASConfig,
328 search_space: SearchSpace,
329 evaluated_architectures: Vec<Architecture>,
330 current_best: Option<Architecture>,
331 search_history: Vec<SearchEvent>,
332 #[allow(dead_code)]
333 predictor: PerformancePredictor,
334 optimizer: ArchitectureOptimizer,
335}
336
337impl NASController {
338 pub fn new(config: NASConfig) -> Self {
339 Self {
340 search_space: SearchSpace::new(&config.search_space),
341 config,
342 evaluated_architectures: Vec::new(),
343 current_best: None,
344 search_history: Vec::new(),
345 predictor: PerformancePredictor::new(),
346 optimizer: ArchitectureOptimizer::new(),
347 }
348 }
349
350 pub fn start_search(&mut self) -> Result<Architecture> {
352 let start_time = Instant::now();
353
354 match self.config.algorithm {
355 NASAlgorithm::DARTS => self.run_darts()?,
356 NASAlgorithm::GDAS => self.run_gdas()?,
357 NASAlgorithm::ENAS => self.run_enas()?,
358 NASAlgorithm::ProxylessNAS => self.run_proxyless_nas()?,
359 NASAlgorithm::Progressive => self.run_progressive_search()?,
360 NASAlgorithm::Evolutionary => self.run_evolutionary_search()?,
361 NASAlgorithm::Random => self.run_random_search()?,
362 }
363
364 let search_duration = start_time.elapsed();
365
366 self.search_history.push(SearchEvent {
368 timestamp: Instant::now(),
369 event_type: SearchEventType::SearchCompleted,
370 duration: search_duration,
371 architectures_evaluated: self.evaluated_architectures.len(),
372 });
373
374 self.current_best
375 .clone()
376 .ok_or_else(|| anyhow::anyhow!("No architecture found during search"))
377 }
378
379 fn run_darts(&mut self) -> Result<()> {
381 println!("Running DARTS algorithm...");
382
383 let mut architecture_weights = self.initialize_architecture_weights()?;
385
386 for _epoch in 0..100 {
388 let architecture = self.sample_architecture_from_weights(&architecture_weights)?;
390
391 let metrics = self.evaluate_architecture(&architecture)?;
393
394 self.update_architecture_weights(&mut architecture_weights, &metrics)?;
396
397 self.evaluated_architectures.push(architecture.clone());
399
400 self.update_best_architecture(&architecture);
402
403 if self.should_early_stop() {
405 break;
406 }
407 }
408
409 Ok(())
410 }
411
412 fn run_gdas(&mut self) -> Result<()> {
414 println!("Running GDAS algorithm...");
415
416 for _epoch in 0..100 {
418 let architecture = self.sample_architecture_gdas()?;
419 let _metrics = self.evaluate_architecture(&architecture)?;
420
421 self.evaluated_architectures.push(architecture.clone());
422 self.update_best_architecture(&architecture);
423
424 if self.should_early_stop() {
425 break;
426 }
427 }
428
429 Ok(())
430 }
431
432 fn run_enas(&mut self) -> Result<()> {
434 println!("Running ENAS algorithm...");
435
436 let mut controller = ENASController::new();
438
439 for _epoch in 0..100 {
440 let architecture = controller.sample_architecture(&self.search_space)?;
442
443 let metrics = self.evaluate_architecture(&architecture)?;
445
446 controller.update_with_reward(&architecture, metrics.accuracy)?;
448
449 self.evaluated_architectures.push(architecture.clone());
450 self.update_best_architecture(&architecture);
451
452 if self.should_early_stop() {
453 break;
454 }
455 }
456
457 Ok(())
458 }
459
460 fn run_proxyless_nas(&mut self) -> Result<()> {
462 println!("Running ProxylessNAS algorithm...");
463
464 for _epoch in 0..100 {
466 let architecture = self.sample_architecture_proxyless()?;
467 let _metrics = self.evaluate_architecture(&architecture)?;
468
469 self.evaluated_architectures.push(architecture.clone());
470 self.update_best_architecture(&architecture);
471
472 if self.should_early_stop() {
473 break;
474 }
475 }
476
477 Ok(())
478 }
479
480 fn run_progressive_search(&mut self) -> Result<()> {
482 println!("Running Progressive search...");
483
484 let complexity_levels = vec![0.2, 0.4, 0.6, 0.8, 1.0];
486
487 for complexity in complexity_levels {
488 for _ in 0..20 {
489 let architecture = self.sample_architecture_with_complexity(complexity)?;
490 let _metrics = self.evaluate_architecture(&architecture)?;
491
492 self.evaluated_architectures.push(architecture.clone());
493 self.update_best_architecture(&architecture);
494 }
495 }
496
497 Ok(())
498 }
499
500 fn run_evolutionary_search(&mut self) -> Result<()> {
502 println!("Running Evolutionary search...");
503
504 let mut population = self.initialize_population(50)?;
506
507 for _generation in 0..100 {
508 for architecture in &population {
510 let _metrics = self.evaluate_architecture(architecture)?;
511 }
513
514 let parents = self.select_parents(&population)?;
516
517 let offspring = self.create_offspring(&parents)?;
519
520 population = self.update_population(population, offspring)?;
522
523 if let Some(best_in_generation) = self.get_best_from_population(&population) {
525 self.update_best_architecture(&best_in_generation);
526 }
527
528 if self.should_early_stop() {
529 break;
530 }
531 }
532
533 Ok(())
534 }
535
536 fn run_random_search(&mut self) -> Result<()> {
538 println!("Running Random search...");
539
540 for _ in 0..self.config.max_architectures {
541 let architecture = self.sample_random_architecture()?;
542 let _metrics = self.evaluate_architecture(&architecture)?;
543
544 self.evaluated_architectures.push(architecture.clone());
545 self.update_best_architecture(&architecture);
546
547 if self.should_early_stop() {
548 break;
549 }
550 }
551
552 Ok(())
553 }
554
555 fn initialize_architecture_weights(&self) -> Result<HashMap<String, f32>> {
557 let mut weights = HashMap::new();
558
559 for operation in &self.config.search_space.operations {
561 weights.insert(format!("{:?}", operation), 0.5);
562 }
563
564 Ok(weights)
565 }
566
567 fn sample_architecture_from_weights(
569 &self,
570 _weights: &HashMap<String, f32>,
571 ) -> Result<Architecture> {
572 let architecture = Architecture {
574 id: format!("arch_{}", uuid::Uuid::new_v4()),
575 encoding: vec![LayerSpec {
576 layer_type: LayerType::Transformer,
577 parameters: HashMap::new(),
578 dimensions: (512, 512),
579 }],
580 metrics: PerformanceMetrics {
581 accuracy: 0.0,
582 loss: 0.0,
583 inference_time: Duration::from_millis(0),
584 memory_usage: 0,
585 parameter_count: 0,
586 flops: 0,
587 },
588 hardware_metrics: HardwareMetrics {
589 gpu_utilization: 0.0,
590 memory_bandwidth: 0.0,
591 power_consumption: 0.0,
592 temperature: 0.0,
593 },
594 training_history: Vec::new(),
595 };
596
597 Ok(architecture)
598 }
599
600 fn update_architecture_weights(
602 &self,
603 weights: &mut HashMap<String, f32>,
604 metrics: &PerformanceMetrics,
605 ) -> Result<()> {
606 let learning_rate = 0.01;
608 for (_, weight) in weights.iter_mut() {
609 *weight += learning_rate * metrics.accuracy;
610 }
611 Ok(())
612 }
613
614 fn sample_architecture_gdas(&self) -> Result<Architecture> {
616 self.sample_random_architecture()
618 }
619
620 fn sample_architecture_with_complexity(&self, complexity: f32) -> Result<Architecture> {
622 let layer_count = (complexity * 48.0) as usize;
624 let mut encoding = Vec::new();
625
626 for _ in 0..layer_count {
627 encoding.push(LayerSpec {
628 layer_type: LayerType::Transformer,
629 parameters: HashMap::new(),
630 dimensions: (512, 512),
631 });
632 }
633
634 Ok(Architecture {
635 id: format!("arch_{}", uuid::Uuid::new_v4()),
636 encoding,
637 metrics: PerformanceMetrics {
638 accuracy: 0.0,
639 loss: 0.0,
640 inference_time: Duration::from_millis(0),
641 memory_usage: 0,
642 parameter_count: 0,
643 flops: 0,
644 },
645 hardware_metrics: HardwareMetrics {
646 gpu_utilization: 0.0,
647 memory_bandwidth: 0.0,
648 power_consumption: 0.0,
649 temperature: 0.0,
650 },
651 training_history: Vec::new(),
652 })
653 }
654
655 fn sample_architecture_proxyless(&self) -> Result<Architecture> {
657 self.sample_random_architecture()
659 }
660
661 fn sample_random_architecture(&self) -> Result<Architecture> {
663 let mut rng = thread_rng();
664
665 let layer_count = rng.random_range(
666 self.config.search_space.depth_range.0..=self.config.search_space.depth_range.1,
667 );
668 let mut encoding = Vec::new();
669
670 for _ in 0..layer_count {
671 encoding.push(LayerSpec {
672 layer_type: LayerType::Transformer,
673 parameters: HashMap::new(),
674 dimensions: (512, 512),
675 });
676 }
677
678 Ok(Architecture {
679 id: format!("arch_{}", uuid::Uuid::new_v4()),
680 encoding,
681 metrics: PerformanceMetrics {
682 accuracy: 0.0,
683 loss: 0.0,
684 inference_time: Duration::from_millis(0),
685 memory_usage: 0,
686 parameter_count: 0,
687 flops: 0,
688 },
689 hardware_metrics: HardwareMetrics {
690 gpu_utilization: 0.0,
691 memory_bandwidth: 0.0,
692 power_consumption: 0.0,
693 temperature: 0.0,
694 },
695 training_history: Vec::new(),
696 })
697 }
698
699 fn evaluate_architecture(
701 &mut self,
702 _architecture: &Architecture,
703 ) -> Result<PerformanceMetrics> {
704 let mut rng = thread_rng();
707
708 let metrics = PerformanceMetrics {
709 accuracy: rng.random_range(0.6..0.95),
710 loss: rng.random_range(0.1..2.0),
711 inference_time: Duration::from_millis(rng.random_range(10..200)),
712 memory_usage: rng.random_range(100_000_000..2_000_000_000),
713 parameter_count: rng.random_range(10_000_000..1_000_000_000),
714 flops: rng.random_range(100_000_000..10_000_000_000),
715 };
716
717 Ok(metrics)
718 }
719
720 fn update_best_architecture(&mut self, architecture: &Architecture) {
722 if let Some(ref current_best) = self.current_best {
723 if architecture.metrics.accuracy > current_best.metrics.accuracy {
724 self.current_best = Some(architecture.clone());
725 }
726 } else {
727 self.current_best = Some(architecture.clone());
728 }
729 }
730
731 fn should_early_stop(&self) -> bool {
733 if !self.config.early_stopping.enabled {
734 return false;
735 }
736
737 if self.evaluated_architectures.len() < self.config.early_stopping.patience {
738 return false;
739 }
740
741 let recent_best = self
743 .evaluated_architectures
744 .iter()
745 .rev()
746 .take(self.config.early_stopping.patience)
747 .max_by(|a, b| {
748 a.metrics
749 .accuracy
750 .partial_cmp(&b.metrics.accuracy)
751 .unwrap_or(std::cmp::Ordering::Equal)
752 });
753
754 if let Some(current_best) = &self.current_best {
755 if let Some(recent_best) = recent_best {
756 return recent_best.metrics.accuracy - current_best.metrics.accuracy
757 < self.config.early_stopping.min_improvement;
758 }
759 }
760
761 false
762 }
763
764 fn initialize_population(&self, size: usize) -> Result<Vec<Architecture>> {
766 let mut population = Vec::new();
767
768 for _ in 0..size {
769 population.push(self.sample_random_architecture()?);
770 }
771
772 Ok(population)
773 }
774
775 fn select_parents(&self, population: &[Architecture]) -> Result<Vec<Architecture>> {
777 let tournament_size = 5;
779 let mut parents = Vec::new();
780 let mut rng = thread_rng();
781
782 for _ in 0..population.len() / 2 {
783 let mut tournament = Vec::new();
784 for _ in 0..tournament_size {
785 let idx = rng.random_range(0..population.len());
786 tournament.push(&population[idx]);
787 }
788
789 let best = tournament
790 .iter()
791 .max_by(|a, b| {
792 a.metrics
793 .accuracy
794 .partial_cmp(&b.metrics.accuracy)
795 .unwrap_or(std::cmp::Ordering::Equal)
796 })
797 .ok_or_else(|| anyhow::anyhow!("Tournament selection failed: empty tournament"))?;
798
799 parents.push((*best).clone());
800 }
801
802 Ok(parents)
803 }
804
805 fn create_offspring(&self, parents: &[Architecture]) -> Result<Vec<Architecture>> {
807 let mut offspring = Vec::new();
808
809 for i in 0..parents.len() {
810 let parent1 = &parents[i];
811 let parent2 = &parents[(i + 1) % parents.len()];
812
813 let mut child_encoding = Vec::new();
815 let min_len = std::cmp::min(parent1.encoding.len(), parent2.encoding.len());
816
817 for j in 0..min_len {
818 if j % 2 == 0 {
819 child_encoding.push(parent1.encoding[j].clone());
820 } else {
821 child_encoding.push(parent2.encoding[j].clone());
822 }
823 }
824
825 let child = Architecture {
826 id: format!("child_{}", uuid::Uuid::new_v4()),
827 encoding: child_encoding,
828 metrics: PerformanceMetrics {
829 accuracy: 0.0,
830 loss: 0.0,
831 inference_time: Duration::from_millis(0),
832 memory_usage: 0,
833 parameter_count: 0,
834 flops: 0,
835 },
836 hardware_metrics: HardwareMetrics {
837 gpu_utilization: 0.0,
838 memory_bandwidth: 0.0,
839 power_consumption: 0.0,
840 temperature: 0.0,
841 },
842 training_history: Vec::new(),
843 };
844
845 offspring.push(child);
846 }
847
848 Ok(offspring)
849 }
850
851 fn update_population(
853 &self,
854 population: Vec<Architecture>,
855 offspring: Vec<Architecture>,
856 ) -> Result<Vec<Architecture>> {
857 let mut combined = population;
858 combined.extend(offspring);
859
860 combined.sort_by(|a, b| {
862 b.metrics
863 .accuracy
864 .partial_cmp(&a.metrics.accuracy)
865 .unwrap_or(std::cmp::Ordering::Equal)
866 });
867 combined.truncate(50); Ok(combined)
870 }
871
872 fn get_best_from_population(&self, population: &[Architecture]) -> Option<Architecture> {
874 population
875 .iter()
876 .max_by(|a, b| {
877 a.metrics
878 .accuracy
879 .partial_cmp(&b.metrics.accuracy)
880 .unwrap_or(std::cmp::Ordering::Equal)
881 })
882 .cloned()
883 }
884
885 pub fn get_search_stats(&self) -> SearchStats {
887 SearchStats {
888 total_architectures_evaluated: self.evaluated_architectures.len(),
889 best_accuracy: self.current_best.as_ref().map(|a| a.metrics.accuracy).unwrap_or(0.0),
890 search_time: self.search_history.iter().map(|e| e.duration).sum::<Duration>(),
891 algorithm_used: self.config.algorithm.clone(),
892 }
893 }
894}
895
896#[allow(dead_code)]
898pub struct SearchSpace {
899 #[allow(dead_code)]
900 operations: Vec<Operation>,
901 depth_range: (usize, usize),
902 width_range: (f32, f32),
903}
904
905impl SearchSpace {
906 pub fn new(config: &SearchSpaceConfig) -> Self {
907 Self {
908 operations: config.operations.clone(),
909 depth_range: config.depth_range,
910 width_range: config.width_range,
911 }
912 }
913}
914
915pub struct PerformancePredictor {
917 #[allow(dead_code)]
918 trained: bool,
919}
920
921impl Default for PerformancePredictor {
922 fn default() -> Self {
923 Self::new()
924 }
925}
926
927impl PerformancePredictor {
928 pub fn new() -> Self {
929 Self { trained: false }
930 }
931
932 pub fn predict(&self, architecture: &Architecture) -> Result<PerformanceMetrics> {
933 Ok(architecture.metrics.clone())
935 }
936}
937
938pub struct ArchitectureOptimizer {
940 #[allow(dead_code)]
941 optimization_active: bool,
942}
943
944impl Default for ArchitectureOptimizer {
945 fn default() -> Self {
946 Self::new()
947 }
948}
949
950impl ArchitectureOptimizer {
951 pub fn new() -> Self {
952 Self {
953 optimization_active: false,
954 }
955 }
956
957 pub fn optimize(&mut self, architecture: &Architecture) -> Result<Architecture> {
958 Ok(architecture.clone())
960 }
961}
962
963pub struct ENASController {
965 #[allow(dead_code)]
966 trained: bool,
967}
968
969impl Default for ENASController {
970 fn default() -> Self {
971 Self::new()
972 }
973}
974
975impl ENASController {
976 pub fn new() -> Self {
977 Self { trained: false }
978 }
979
980 pub fn sample_architecture(&self, _search_space: &SearchSpace) -> Result<Architecture> {
981 Ok(Architecture {
983 id: format!("enas_{}", uuid::Uuid::new_v4()),
984 encoding: vec![LayerSpec {
985 layer_type: LayerType::Transformer,
986 parameters: HashMap::new(),
987 dimensions: (512, 512),
988 }],
989 metrics: PerformanceMetrics {
990 accuracy: 0.0,
991 loss: 0.0,
992 inference_time: Duration::from_millis(0),
993 memory_usage: 0,
994 parameter_count: 0,
995 flops: 0,
996 },
997 hardware_metrics: HardwareMetrics {
998 gpu_utilization: 0.0,
999 memory_bandwidth: 0.0,
1000 power_consumption: 0.0,
1001 temperature: 0.0,
1002 },
1003 training_history: Vec::new(),
1004 })
1005 }
1006
1007 pub fn update_with_reward(&mut self, _architecture: &Architecture, _reward: f32) -> Result<()> {
1008 Ok(())
1010 }
1011}
1012
1013#[derive(Debug, Clone)]
1015pub struct SearchEvent {
1016 pub timestamp: Instant,
1017 pub event_type: SearchEventType,
1018 pub duration: Duration,
1019 pub architectures_evaluated: usize,
1020}
1021
/// Kind of entry recorded in the search history.
///
/// This is a fieldless enum, so it derives `Copy` and `PartialEq`/`Eq`/`Hash`. That lets
/// event kinds be compared, filtered and counted without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SearchEventType {
    /// A search run started.
    SearchStarted,
    /// A search run finished.
    SearchCompleted,
    /// An architecture was evaluated.
    ArchitectureEvaluated,
    /// The best architecture changed.
    BestArchitectureUpdated,
    /// The search was ended by early stopping.
    EarlyStoppingStopped,
}
1031
1032#[derive(Debug, Clone, Serialize, Deserialize)]
1034pub struct SearchStats {
1035 pub total_architectures_evaluated: usize,
1036 pub best_accuracy: f32,
1037 pub search_time: Duration,
1038 pub algorithm_used: NASAlgorithm,
1039}
1040
#[cfg(test)]
mod tests {
    use super::*;

    /// Controller built from `NASConfig::default()`.
    fn default_controller() -> NASController {
        NASController::new(NASConfig::default())
    }

    #[test]
    fn test_nas_controller_creation() {
        let controller = default_controller();

        assert!(controller.evaluated_architectures.is_empty());
        assert!(controller.current_best.is_none());
    }

    #[test]
    fn test_random_architecture_sampling() {
        let architecture = default_controller()
            .sample_random_architecture()
            .expect("operation failed in test");

        assert!(!architecture.id.is_empty());
        assert!(!architecture.encoding.is_empty());
    }

    #[test]
    fn test_architecture_evaluation() {
        let mut controller = default_controller();
        let architecture = controller
            .sample_random_architecture()
            .expect("operation failed in test");

        let metrics = controller
            .evaluate_architecture(&architecture)
            .expect("operation failed in test");

        assert!((0.0..=1.0).contains(&metrics.accuracy));
        assert!(metrics.loss >= 0.0);
    }

    #[test]
    fn test_early_stopping() {
        let early_stopping = EarlyStoppingConfig {
            enabled: true,
            patience: 5,
            min_improvement: 0.1,
        };
        let controller = NASController::new(NASConfig {
            early_stopping,
            ..Default::default()
        });

        // Nothing has been evaluated yet, so there is no basis for stopping.
        assert!(!controller.should_early_stop());
    }

    #[test]
    fn test_population_initialization() {
        let population = default_controller()
            .initialize_population(10)
            .expect("operation failed in test");

        assert_eq!(population.len(), 10);
        assert!(population.iter().all(|arch| !arch.id.is_empty()));
    }

    #[test]
    fn test_search_space_creation() {
        let search_space = SearchSpace::new(&SearchSpaceConfig::default());

        assert!(!search_space.operations.is_empty());
        let (min_depth, max_depth) = search_space.depth_range;
        assert!(min_depth <= max_depth);
    }
}