1use super::{
8 ActivationType, LearnedOptimizationConfig, LearnedOptimizer, MetaOptimizerState,
9 OptimizationProblem, TrainingTask,
10};
11use crate::error::OptimizeResult;
12use crate::result::OptimizeResults;
13use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1};
14use scirs2_core::random::Rng;
15use statrs::statistics::Statistics;
16use std::collections::HashMap;
17
/// Adaptive neural-architecture-search (NAS) system that evolves optimizer
/// architectures for learned optimization.
#[derive(Debug, Clone)]
pub struct AdaptiveNASSystem {
    // Search configuration (population size, episode count, learning rates).
    config: LearnedOptimizationConfig,
    // Current population of candidate architectures, kept sorted best-first
    // after each selection step.
    architecture_population: Vec<OptimizationArchitecture>,
    // Per-architecture history of evaluation scores, keyed by architecture id.
    performance_history: HashMap<ArchitectureId, Vec<f64>>,
    // LSTM-based controller that biases architecture generation.
    controller: ArchitectureController,
    // Meta-optimizer bookkeeping exposed via the `LearnedOptimizer` trait.
    meta_state: MetaOptimizerState,
    // Preferred architecture cached per problem class.
    architecture_cache: HashMap<String, OptimizationArchitecture>,
    // Aggregate statistics collected during the search.
    search_stats: NASSearchStats,
    // Current search generation (0-based).
    generation: usize,
}
38
/// Identifier used to key architectures (a randomly generated string).
type ArchitectureId = String;
41
/// A candidate optimizer architecture produced by the search.
#[derive(Debug, Clone)]
pub struct OptimizationArchitecture {
    /// Unique identifier, assigned at generation/mutation time.
    pub id: ArchitectureId,
    /// Ordered layer stack.
    pub layers: Vec<LayerConfig>,
    /// Directed edges between layers (forward and residual links).
    pub connections: Vec<Connection>,
    /// One activation per layer, parallel to `layers`.
    pub activations: Vec<ActivationType>,
    /// Additional skip links (currently left empty by the generator).
    pub skip_connections: Vec<SkipConnection>,
    /// Algorithmic components (momentum, trust region, ...) used when optimizing.
    pub optimizer_components: Vec<OptimizerComponent>,
    /// Complexity estimate; feeds the evaluation penalty (0.0 = unset).
    pub complexity: f64,
    /// Evaluation metrics collected during the search.
    pub performance_metrics: ArchitectureMetrics,
}
62
/// Configuration of a single layer within a candidate architecture.
#[derive(Debug, Clone)]
pub struct LayerConfig {
    /// Kind of layer (dense, attention, recurrent, ...).
    pub layer_type: LayerType,
    /// Width (number of units) of the layer.
    pub units: usize,
    /// Dropout probability, sampled in [0, 0.5).
    pub dropout: f64,
    /// Normalization applied to the layer's output.
    pub normalization: NormalizationType,
    /// Free-form extra hyperparameters, keyed by name.
    pub parameters: HashMap<String, f64>,
}
77
/// Supported layer kinds for searched architectures.
#[derive(Debug, Clone)]
pub enum LayerType {
    /// Fully connected layer.
    Dense,
    /// Convolutional layer.
    Convolution { kernel_size: usize, stride: usize },
    /// Multi-head attention layer.
    Attention { num_heads: usize },
    /// Long short-term memory cell.
    LSTM { hidden_size: usize },
    /// Gated recurrent unit cell.
    GRU { hidden_size: usize },
    /// Transformer block (attention plus feed-forward of width `ff_dim`).
    Transformer { num_heads: usize, ff_dim: usize },
    /// Graph neural-network layer; `aggregation` names the pooling scheme.
    GraphNN { aggregation: String },
    /// External memory module with `memory_size` slots.
    Memory { memory_size: usize },
}
98
/// Normalization strategies a layer can apply to its output.
#[derive(Debug, Clone)]
pub enum NormalizationType {
    /// No normalization.
    None,
    /// Batch normalization.
    BatchNorm,
    /// Layer normalization.
    LayerNorm,
    /// Group normalization with the given number of groups.
    GroupNorm { groups: usize },
    /// Instance normalization.
    InstanceNorm,
}
108
/// Directed, weighted edge between two layers of an architecture.
#[derive(Debug, Clone)]
pub struct Connection {
    /// Index of the source layer.
    pub from: usize,
    /// Index of the destination layer.
    pub to: usize,
    /// Edge weight (the generator uses 1.0 for forward links, 0.5 for residuals).
    pub weight: f64,
    /// Semantic kind of the link.
    pub connection_type: ConnectionType,
}
121
/// Kinds of inter-layer connections.
#[derive(Debug, Clone)]
pub enum ConnectionType {
    /// Plain feed-forward link to the next layer.
    Forward,
    /// Residual (skip-add) link from an earlier layer.
    Residual,
    /// Dense (concatenating) connectivity.
    Dense,
    /// Attention-mediated connection.
    Attention,
}
134
/// Explicit skip link between two layers.
#[derive(Debug, Clone)]
pub struct SkipConnection {
    /// Index of the layer the skip starts from.
    pub source: usize,
    /// Index of the layer the skip feeds into.
    pub target: usize,
    /// How the skipped signal is merged at the target.
    pub skip_type: SkipType,
}
145
/// Merge strategies for skip connections.
#[derive(Debug, Clone)]
pub enum SkipType {
    /// Element-wise addition.
    Add,
    /// Channel concatenation.
    Concat,
    /// Learned gate of the given size controls the merge.
    Gated { gate_size: usize },
    /// Highway-network style gated merge.
    Highway,
}
158
/// Building blocks that a searched architecture combines into an
/// optimization strategy.
#[derive(Debug, Clone)]
pub enum OptimizerComponent {
    /// Classical momentum with the given decay coefficient.
    Momentum { decay: f64 },
    /// Learning-rate adaptation, clamped to [`min_lr`, `max_lr`].
    AdaptiveLR {
        adaptation_rate: f64,
        min_lr: f64,
        max_lr: f64,
    },
    /// Quasi-Newton step using the chosen Hessian approximation.
    SecondOrder {
        hessian_approximation: HessianApprox,
        regularization: f64,
    },
    /// Trust-region control of the step length.
    TrustRegion {
        initial_radius: f64,
        max_radius: f64,
        shrink_factor: f64,
        expand_factor: f64,
    },
    /// Line search along the step direction, capped at `max_nit` iterations.
    LineSearch {
        method: LineSearchMethod,
        max_nit: usize,
    },
    /// Elastic-net style penalty on the parameters.
    Regularization {
        l1_weight: f64,
        l2_weight: f64,
        elastic_net_ratio: f64,
    },
}
194
/// Hessian approximation schemes for second-order components.
#[derive(Debug, Clone)]
pub enum HessianApprox {
    /// Full-memory BFGS update.
    BFGS,
    /// Limited-memory BFGS keeping `memory_size` curvature pairs.
    LBFGS { memory_size: usize },
    /// Symmetric rank-one update.
    SR1,
    /// Davidon–Fletcher–Powell update.
    DFP,
    /// Diagonal-only approximation.
    DiagonalApprox,
}
204
/// Line-search strategies for the `LineSearch` optimizer component.
#[derive(Debug, Clone)]
pub enum LineSearchMethod {
    /// Simple backtracking.
    Backtracking,
    /// Strong Wolfe conditions.
    StrongWolfe,
    /// Moré–Thuente search.
    MoreThuente,
    /// Armijo sufficient-decrease rule.
    Armijo,
    /// Exact minimization along the direction.
    Exact,
}
214
/// Aggregated quality metrics for one architecture (all default to 0.0).
#[derive(Debug, Clone)]
pub struct ArchitectureMetrics {
    /// Mean evaluation score; used as the fitness during selection.
    pub convergence_rate: f64,
    /// Fraction of successful runs (not yet populated by this module).
    pub success_rate: f64,
    /// Average number of function evaluations (not yet populated).
    pub avg_evaluations: f64,
    /// Robustness score (not yet populated).
    pub robustness: f64,
    /// Transferability score (not yet populated).
    pub transfer_score: f64,
    /// Efficiency score (not yet populated).
    pub efficiency: f64,
}
231
232impl Default for ArchitectureMetrics {
233 fn default() -> Self {
234 Self {
235 convergence_rate: 0.0,
236 success_rate: 0.0,
237 avg_evaluations: 0.0,
238 robustness: 0.0,
239 transfer_score: 0.0,
240 efficiency: 0.0,
241 }
242 }
243}
244
/// LSTM-based controller that steers architecture generation.
#[derive(Debug, Clone)]
pub struct ArchitectureController {
    // Gate weight tensors, shape (4, hidden, hidden).
    lstm_weights: Array3<f64>,
    // Token embedding matrix, shape (hidden, vocab_size).
    embedding_layer: Array2<f64>,
    // Projection back to vocabulary logits, shape (vocab_size, hidden).
    output_layer: Array2<f64>,
    // Current hidden state; its mean steers depth decisions during generation.
    controller_state: Array1<f64>,
    // Token vocabulary shared with the search.
    vocabulary: ArchitectureVocabulary,
}
259
/// Maps architecture tokens (layer kinds, activations, components) to ids.
#[derive(Debug, Clone)]
pub struct ArchitectureVocabulary {
    /// Layer-type tokens (ids 0-7).
    pub layer_types: HashMap<String, usize>,
    /// Activation tokens (ids 8-12).
    pub activations: HashMap<String, usize>,
    /// Optimizer-component tokens (ids 13-18).
    pub components: HashMap<String, usize>,
    /// Total number of tokens (19).
    pub vocab_size: usize,
}
272
/// Running statistics of the architecture search.
#[derive(Debug, Clone)]
pub struct NASSearchStats {
    /// Total number of architecture evaluations performed.
    pub architectures_evaluated: usize,
    /// Best convergence-rate score seen so far (starts at -inf).
    pub best_performance: f64,
    /// Search-efficiency indicator (not yet updated by this module).
    pub search_efficiency: f64,
    /// Standard deviation of population scores.
    pub population_diversity: f64,
    /// Per-generation trace of `best_performance`; drives convergence checks.
    pub convergence_indicators: Vec<f64>,
}
287
288impl Default for NASSearchStats {
289 fn default() -> Self {
290 Self {
291 architectures_evaluated: 0,
292 best_performance: f64::NEG_INFINITY,
293 search_efficiency: 0.0,
294 population_diversity: 1.0,
295 convergence_indicators: Vec::new(),
296 }
297 }
298}
299
300impl AdaptiveNASSystem {
301 pub fn new(config: LearnedOptimizationConfig) -> Self {
303 let vocabulary = ArchitectureVocabulary::new();
304 let controller = ArchitectureController::new(&vocabulary, config.hidden_size);
305 let hidden_size = config.hidden_size;
306
307 Self {
308 config,
309 architecture_population: Vec::new(),
310 performance_history: HashMap::new(),
311 controller,
312 meta_state: MetaOptimizerState {
313 meta_params: Array1::zeros(100),
314 network_weights: Array2::zeros((hidden_size, hidden_size)),
315 performance_history: Vec::new(),
316 adaptation_stats: super::AdaptationStatistics::default(),
317 episode: 0,
318 },
319 architecture_cache: HashMap::new(),
320 search_stats: NASSearchStats::default(),
321 generation: 0,
322 }
323 }
324
325 pub fn search_architectures(
327 &mut self,
328 training_problems: &[OptimizationProblem],
329 ) -> OptimizeResult<Vec<OptimizationArchitecture>> {
330 if self.architecture_population.is_empty() {
332 self.initialize_population()?;
333 }
334
335 for generation in 0..self.config.meta_training_episodes {
336 self.generation = generation;
337
338 self.evaluate_population(training_problems)?;
340
341 self.update_controller()?;
343
344 let new_architectures = self.generate_new_architectures()?;
346
347 self.select_next_generation(new_architectures)?;
349
350 self.update_search_stats()?;
352
353 if self.check_convergence() {
355 break;
356 }
357 }
358
359 Ok(self.get_best_architectures())
360 }
361
362 fn initialize_population(&mut self) -> OptimizeResult<()> {
364 for _ in 0..self.config.batch_size {
365 let architecture = self.generate_random_architecture()?;
366 self.architecture_population.push(architecture);
367 }
368 Ok(())
369 }
370
371 fn generate_random_architecture(&self) -> OptimizeResult<OptimizationArchitecture> {
373 let num_layers = 2 + (scirs2_core::random::rng().random_range(0..8)); let mut layers = Vec::new();
375 let mut connections = Vec::new();
376 let mut activations = Vec::new();
377 let mut optimizer_components = Vec::new();
378
379 for i in 0..num_layers {
381 let layer_type = self.sample_layer_type();
382 let units = 16 + (scirs2_core::random::rng().random_range(0..256)); layers.push(LayerConfig {
385 layer_type,
386 units,
387 dropout: scirs2_core::random::rng().random_range(0.0..0.5),
388 normalization: self.sample_normalization(),
389 parameters: HashMap::new(),
390 });
391
392 activations.push(self.sample_activation());
393
394 if i > 0 {
396 connections.push(Connection {
397 from: i - 1,
398 to: i,
399 weight: 1.0,
400 connection_type: ConnectionType::Forward,
401 });
402
403 if i > 1 && scirs2_core::random::rng().random_range(0.0..1.0) < 0.3 {
405 let skip_source = scirs2_core::random::rng().random_range(0..i);
406 connections.push(Connection {
407 from: skip_source,
408 to: i,
409 weight: 0.5,
410 connection_type: ConnectionType::Residual,
411 });
412 }
413 }
414 }
415
416 for _ in 0..(1 + scirs2_core::random::rng().random_range(0..4)) {
418 optimizer_components.push(self.sample_optimizer_component());
419 }
420
421 let id = format!(
422 "arch_{}",
423 scirs2_core::random::rng().random_range(0..u64::MAX)
424 );
425
426 Ok(OptimizationArchitecture {
427 id,
428 layers,
429 connections,
430 activations,
431 skip_connections: Vec::new(),
432 optimizer_components,
433 complexity: 0.0,
434 performance_metrics: ArchitectureMetrics::default(),
435 })
436 }
437
438 fn sample_layer_type(&self) -> LayerType {
439 match scirs2_core::random::rng().random_range(0..8) {
440 0 => LayerType::Dense,
441 1 => LayerType::Attention {
442 num_heads: 2 + scirs2_core::random::rng().random_range(0..6),
443 },
444 2 => LayerType::LSTM {
445 hidden_size: 32 + scirs2_core::random::rng().random_range(0..128),
446 },
447 3 => LayerType::GRU {
448 hidden_size: 32 + scirs2_core::random::rng().random_range(0..128),
449 },
450 4 => LayerType::Transformer {
451 num_heads: 2 + scirs2_core::random::rng().random_range(0..6),
452 ff_dim: 64 + scirs2_core::random::rng().random_range(0..256),
453 },
454 5 => LayerType::Memory {
455 memory_size: 16 + scirs2_core::random::rng().random_range(0..64),
456 },
457 6 => LayerType::Convolution {
458 kernel_size: 1 + scirs2_core::random::rng().random_range(0..5),
459 stride: 1 + scirs2_core::random::rng().random_range(0..3),
460 },
461 _ => LayerType::GraphNN {
462 aggregation: "mean".to_string(),
463 },
464 }
465 }
466
467 fn sample_normalization(&self) -> NormalizationType {
468 match scirs2_core::random::rng().random_range(0..5) {
469 0 => NormalizationType::None,
470 1 => NormalizationType::BatchNorm,
471 2 => NormalizationType::LayerNorm,
472 3 => NormalizationType::GroupNorm {
473 groups: 2 + scirs2_core::random::rng().random_range(0..6),
474 },
475 _ => NormalizationType::InstanceNorm,
476 }
477 }
478
479 fn sample_activation(&self) -> ActivationType {
480 match scirs2_core::random::rng().random_range(0..5) {
481 0 => ActivationType::ReLU,
482 1 => ActivationType::GELU,
483 2 => ActivationType::Swish,
484 3 => ActivationType::Tanh,
485 _ => ActivationType::LeakyReLU,
486 }
487 }
488
489 fn sample_optimizer_component(&self) -> OptimizerComponent {
490 match scirs2_core::random::rng().random_range(0..6) {
491 0 => OptimizerComponent::Momentum {
492 decay: 0.8 + scirs2_core::random::rng().random_range(0.0..0.19),
493 },
494 1 => OptimizerComponent::AdaptiveLR {
495 adaptation_rate: 0.001 + scirs2_core::random::rng().random_range(0.0..0.009),
496 min_lr: 1e-8,
497 max_lr: 1.0,
498 },
499 2 => OptimizerComponent::SecondOrder {
500 hessian_approximation: HessianApprox::LBFGS {
501 memory_size: 5 + scirs2_core::random::rng().random_range(0..15),
502 },
503 regularization: 1e-6 + scirs2_core::random::rng().random_range(0.0..1e-3),
504 },
505 3 => OptimizerComponent::TrustRegion {
506 initial_radius: 0.1 + scirs2_core::random::rng().random_range(0.0..0.9),
507 max_radius: 10.0,
508 shrink_factor: 0.25,
509 expand_factor: 2.0,
510 },
511 4 => OptimizerComponent::LineSearch {
512 method: LineSearchMethod::StrongWolfe,
513 max_nit: 10 + scirs2_core::random::rng().random_range(0..20),
514 },
515 _ => OptimizerComponent::Regularization {
516 l1_weight: scirs2_core::random::rng().random_range(0.0..0.01),
517 l2_weight: scirs2_core::random::rng().random_range(0.0..0.01),
518 elastic_net_ratio: scirs2_core::random::rng().random_range(0.0..1.0),
519 },
520 }
521 }
522
523 fn evaluate_population(
525 &mut self,
526 training_problems: &[OptimizationProblem],
527 ) -> OptimizeResult<()> {
528 let scores: Vec<_> = self
530 .architecture_population
531 .iter()
532 .map(|architecture| {
533 let mut total_score = 0.0;
534 let mut num_evaluated = 0;
535
536 for problem in training_problems.iter().take(5) {
537 if let Ok(score) = self.evaluate_architecture_on_problem(architecture, problem)
539 {
540 total_score += score;
541 num_evaluated += 1;
542 }
543 }
544
545 if num_evaluated > 0 {
546 Some(total_score / num_evaluated as f64)
547 } else {
548 None
549 }
550 })
551 .collect();
552
553 for (architecture, score) in self.architecture_population.iter_mut().zip(scores.iter()) {
555 if let Some(avg_score) = score {
556 architecture.performance_metrics.convergence_rate = *avg_score;
557
558 self.performance_history
560 .entry(architecture.id.clone())
561 .or_default()
562 .push(*avg_score);
563 }
564 }
565
566 Ok(())
567 }
568
569 fn evaluate_architecture_on_problem(
571 &self,
572 architecture: &OptimizationArchitecture,
573 problem: &OptimizationProblem,
574 ) -> OptimizeResult<f64> {
575 let complexity_penalty = architecture.complexity * 0.01;
577 let num_components = architecture.optimizer_components.len() as f64;
578 let num_layers = architecture.layers.len() as f64;
579
580 let base_score = (num_components * 0.1 + num_layers * 0.05).min(1.0);
582 let final_score = base_score - complexity_penalty;
583
584 Ok(final_score.max(0.0))
585 }
586
587 fn update_controller(&mut self) -> OptimizeResult<()> {
589 let mut rewards = Vec::new();
591 for arch in &self.architecture_population {
592 rewards.push(arch.performance_metrics.convergence_rate);
593 }
594
595 if rewards.is_empty() {
596 return Ok(());
597 }
598
599 let baseline = rewards.iter().sum::<f64>() / rewards.len() as f64;
601
602 for (i, &reward) in rewards.iter().enumerate() {
603 let advantage = reward - baseline;
604
605 let lstm_len = self.controller.lstm_weights.len();
607 if i < lstm_len {
608 let shape = self.controller.lstm_weights.shape();
609 let dims = (shape[0], shape[1], shape[2]);
610 for j in 0..dims.1 {
611 for k in 0..dims.2 {
612 let learning_rate = self.config.meta_learning_rate;
613 let idx = (i % lstm_len, j, k);
614 self.controller.lstm_weights[idx] += learning_rate * advantage * 0.01;
615 }
616 }
617 }
618 }
619
620 Ok(())
621 }
622
623 fn generate_new_architectures(&mut self) -> OptimizeResult<Vec<OptimizationArchitecture>> {
625 let mut new_architectures = Vec::new();
626
627 for _ in 0..self.config.batch_size / 2 {
628 let architecture = self.controller_generate_architecture()?;
630 new_architectures.push(architecture);
631
632 if !self.architecture_population.is_empty() {
634 let best_idx = self.get_best_architecture_index();
635 let mutated = self.mutate_architecture(&self.architecture_population[best_idx])?;
636 new_architectures.push(mutated);
637 }
638 }
639
640 Ok(new_architectures)
641 }
642
643 fn controller_generate_architecture(&mut self) -> OptimizeResult<OptimizationArchitecture> {
645 let mut architecture = self.generate_random_architecture()?;
649
650 let controller_influence = self.controller.controller_state.view().mean();
652
653 if controller_influence > 0.5 {
655 if architecture.layers.len() < 10 {
657 architecture.layers.push(LayerConfig {
658 layer_type: LayerType::Dense,
659 units: 64,
660 dropout: 0.1,
661 normalization: NormalizationType::LayerNorm,
662 parameters: HashMap::new(),
663 });
664 }
665 } else {
666 if architecture.layers.len() > 2 {
668 architecture.layers.pop();
669 }
670 }
671
672 Ok(architecture)
673 }
674
675 fn mutate_architecture(
677 &self,
678 base_arch: &OptimizationArchitecture,
679 ) -> OptimizeResult<OptimizationArchitecture> {
680 let mut mutated = base_arch.clone();
681 mutated.id = format!(
682 "mutated_{}",
683 scirs2_core::random::rng().random_range(0..u64::MAX)
684 );
685
686 if scirs2_core::random::rng().random_range(0.0..1.0) < 0.3 {
688 if scirs2_core::random::rng().random_range(0.0..1.0) < 0.5 && mutated.layers.len() < 12
690 {
691 mutated.layers.push(LayerConfig {
692 layer_type: self.sample_layer_type(),
693 units: 32 + scirs2_core::random::rng().random_range(0..128),
694 dropout: scirs2_core::random::rng().random_range(0.0..0.5),
695 normalization: self.sample_normalization(),
696 parameters: HashMap::new(),
697 });
698 } else if mutated.layers.len() > 2 {
699 mutated.layers.pop();
700 }
701 }
702
703 for activation in &mut mutated.activations {
705 if scirs2_core::random::rng().random_range(0.0..1.0) < 0.2 {
706 *activation = self.sample_activation();
707 }
708 }
709
710 if scirs2_core::random::rng().random_range(0.0..1.0) < 0.4 {
712 if scirs2_core::random::rng().random_range(0.0..1.0) < 0.5
713 && mutated.optimizer_components.len() < 6
714 {
715 mutated
716 .optimizer_components
717 .push(self.sample_optimizer_component());
718 } else if !mutated.optimizer_components.is_empty() {
719 let idx =
720 scirs2_core::random::rng().random_range(0..mutated.optimizer_components.len());
721 mutated.optimizer_components.remove(idx);
722 }
723 }
724
725 Ok(mutated)
726 }
727
728 fn select_next_generation(
730 &mut self,
731 mut new_architectures: Vec<OptimizationArchitecture>,
732 ) -> OptimizeResult<()> {
733 self.architecture_population.append(&mut new_architectures);
735
736 self.architecture_population.sort_by(|a, b| {
738 b.performance_metrics
739 .convergence_rate
740 .partial_cmp(&a.performance_metrics.convergence_rate)
741 .unwrap_or(std::cmp::Ordering::Equal)
742 });
743
744 self.architecture_population
746 .truncate(self.config.batch_size);
747
748 Ok(())
749 }
750
751 fn update_search_stats(&mut self) -> OptimizeResult<()> {
753 self.search_stats.architectures_evaluated += self.architecture_population.len();
754
755 if let Some(best_arch) = self.architecture_population.first() {
756 let best_performance = best_arch.performance_metrics.convergence_rate;
757 if best_performance > self.search_stats.best_performance {
758 self.search_stats.best_performance = best_performance;
759 }
760 }
761
762 let performances: Vec<f64> = self
764 .architecture_population
765 .iter()
766 .map(|a| a.performance_metrics.convergence_rate)
767 .collect();
768
769 if performances.len() > 1 {
770 let mean = performances.iter().sum::<f64>() / performances.len() as f64;
771 let variance = performances
772 .iter()
773 .map(|&p| (p - mean).powi(2))
774 .sum::<f64>()
775 / performances.len() as f64;
776 self.search_stats.population_diversity = variance.sqrt();
777 }
778
779 self.search_stats
780 .convergence_indicators
781 .push(self.search_stats.best_performance);
782
783 Ok(())
784 }
785
786 fn check_convergence(&self) -> bool {
788 if self.search_stats.convergence_indicators.len() < 10 {
789 return false;
790 }
791
792 let recent_improvements: Vec<f64> = self
794 .search_stats
795 .convergence_indicators
796 .windows(2)
797 .map(|w| w[1] - w[0])
798 .collect();
799
800 let avg_improvement =
801 recent_improvements.iter().sum::<f64>() / recent_improvements.len() as f64;
802 avg_improvement < 1e-6
803 }
804
805 fn get_best_architectures(&self) -> Vec<OptimizationArchitecture> {
807 self.architecture_population.clone()
808 }
809
810 fn get_best_architecture_index(&self) -> usize {
811 self.architecture_population
812 .iter()
813 .enumerate()
814 .max_by(|(_, a), (_, b)| {
815 a.performance_metrics
816 .convergence_rate
817 .partial_cmp(&b.performance_metrics.convergence_rate)
818 .unwrap_or(std::cmp::Ordering::Equal)
819 })
820 .map(|(i, _)| i)
821 .unwrap_or(0)
822 }
823
    /// Read-only view of the accumulated search statistics.
    pub fn get_search_stats(&self) -> &NASSearchStats {
        &self.search_stats
    }
828
    /// Stores `architecture` as the preferred choice for `problem_class`,
    /// replacing any previously cached entry for that class.
    pub fn cache_architecture_for_problem(
        &mut self,
        problem_class: String,
        architecture: OptimizationArchitecture,
    ) {
        self.architecture_cache.insert(problem_class, architecture);
    }
837
    /// Looks up a previously cached architecture for `problem_class`, if any.
    pub fn get_cached_architecture(
        &self,
        problem_class: &str,
    ) -> Option<&OptimizationArchitecture> {
        self.architecture_cache.get(problem_class)
    }
845}
846
847impl ArchitectureController {
848 pub fn new(vocabulary: &ArchitectureVocabulary, hidden_size: usize) -> Self {
850 Self {
851 lstm_weights: Array3::from_shape_fn((4, hidden_size, hidden_size), |_| {
852 (scirs2_core::random::rng().random_range(0.0..1.0) - 0.5) * 0.1
853 }),
854 embedding_layer: Array2::from_shape_fn((hidden_size, vocabulary.vocab_size), |_| {
855 (scirs2_core::random::rng().random_range(0.0..1.0) - 0.5) * 0.1
856 }),
857 output_layer: Array2::from_shape_fn((vocabulary.vocab_size, hidden_size), |_| {
858 (scirs2_core::random::rng().random_range(0.0..1.0) - 0.5) * 0.1
859 }),
860 controller_state: Array1::zeros(hidden_size),
861 vocabulary: vocabulary.clone(),
862 }
863 }
864}
865
impl Default for ArchitectureVocabulary {
    /// Delegates to [`ArchitectureVocabulary::new`].
    fn default() -> Self {
        Self::new()
    }
}
871
872impl ArchitectureVocabulary {
873 pub fn new() -> Self {
875 let mut layer_types = HashMap::new();
876 layer_types.insert("dense".to_string(), 0);
877 layer_types.insert("conv".to_string(), 1);
878 layer_types.insert("attention".to_string(), 2);
879 layer_types.insert("lstm".to_string(), 3);
880 layer_types.insert("gru".to_string(), 4);
881 layer_types.insert("transformer".to_string(), 5);
882 layer_types.insert("graph".to_string(), 6);
883 layer_types.insert("memory".to_string(), 7);
884
885 let mut activations = HashMap::new();
886 activations.insert("relu".to_string(), 8);
887 activations.insert("gelu".to_string(), 9);
888 activations.insert("swish".to_string(), 10);
889 activations.insert("tanh".to_string(), 11);
890 activations.insert("leaky_relu".to_string(), 12);
891
892 let mut components = HashMap::new();
893 components.insert("momentum".to_string(), 13);
894 components.insert("adaptive_lr".to_string(), 14);
895 components.insert("second_order".to_string(), 15);
896 components.insert("trust_region".to_string(), 16);
897 components.insert("line_search".to_string(), 17);
898 components.insert("regularization".to_string(), 18);
899
900 Self {
901 layer_types,
902 activations,
903 components,
904 vocab_size: 19,
905 }
906 }
907}
908
909impl LearnedOptimizer for AdaptiveNASSystem {
910 fn meta_train(&mut self, training_tasks: &[TrainingTask]) -> OptimizeResult<()> {
911 let problems: Vec<OptimizationProblem> = training_tasks
912 .iter()
913 .map(|task| task.problem.clone())
914 .collect();
915
916 self.search_architectures(&problems)?;
917 Ok(())
918 }
919
920 fn adapt_to_problem(
921 &mut self,
922 problem: &OptimizationProblem,
923 initial_params: &ArrayView1<f64>,
924 ) -> OptimizeResult<()> {
925 if let Some(cached_arch) = self.get_cached_architecture(&problem.problem_class) {
927 return Ok(());
929 }
930
931 let specialized_arch = self.generate_random_architecture()?;
933 self.cache_architecture_for_problem(problem.problem_class.clone(), specialized_arch);
934
935 Ok(())
936 }
937
938 fn optimize<F>(
939 &mut self,
940 objective: F,
941 initial_params: &ArrayView1<f64>,
942 ) -> OptimizeResult<OptimizeResults<f64>>
943 where
944 F: Fn(&ArrayView1<f64>) -> f64,
945 {
946 if self.architecture_population.is_empty() {
948 self.initialize_population()?;
949 }
950
951 let best_idx = self.get_best_architecture_index();
952 let best_arch = &self.architecture_population[best_idx];
953
954 let mut current_params = initial_params.to_owned();
956 let mut best_value = objective(initial_params);
957 let mut iterations = 0;
958
959 for iter in 0..1000 {
960 iterations = iter;
961
962 let step_size = self.compute_step_size(best_arch, iter);
964 let direction = self.compute_search_direction(&objective, ¤t_params, best_arch);
965
966 for i in 0..current_params.len() {
968 current_params[i] -= step_size * direction[i];
969 }
970
971 let current_value = objective(¤t_params.view());
972
973 if current_value < best_value {
974 best_value = current_value;
975 }
976
977 if step_size < 1e-8 {
979 break;
980 }
981 }
982
983 Ok(OptimizeResults::<f64> {
984 x: current_params,
985 fun: best_value,
986 success: true,
987 nit: iterations,
988 message: format!(
989 "NAS optimization completed using architecture: {}",
990 best_arch.id
991 ),
992 jac: None,
993 hess: None,
994 constr: None,
995 nfev: iterations * best_arch.layers.len(), njev: 0,
997 nhev: 0,
998 maxcv: 0,
999 status: 0,
1000 })
1001 }
1002
    /// Read-only access to the meta-optimizer state.
    fn get_state(&self) -> &MetaOptimizerState {
        &self.meta_state
    }
1006
    /// Clears the population, score history, stats, and generation counter.
    ///
    /// NOTE(review): the per-problem architecture cache, the controller
    /// weights, and `meta_state` are left intact — confirm this partial reset
    /// is the intended semantics.
    fn reset(&mut self) {
        self.architecture_population.clear();
        self.performance_history.clear();
        self.search_stats = NASSearchStats::default();
        self.generation = 0;
    }
1013}
1014
1015impl AdaptiveNASSystem {
1016 fn compute_step_size(&self, architecture: &OptimizationArchitecture, iteration: usize) -> f64 {
1017 let mut step_size = 0.01;
1018
1019 for component in &architecture.optimizer_components {
1021 match component {
1022 OptimizerComponent::AdaptiveLR {
1023 adaptation_rate,
1024 min_lr,
1025 max_lr,
1026 } => {
1027 step_size *= 1.0 + adaptation_rate * (iteration as f64).cos();
1028 step_size = step_size.max(*min_lr).min(*max_lr);
1029 }
1030 OptimizerComponent::TrustRegion { initial_radius, .. } => {
1031 step_size = step_size.min(*initial_radius);
1032 }
1033 _ => {}
1034 }
1035 }
1036
1037 step_size / (1.0 + iteration as f64 * 0.001)
1038 }
1039
1040 fn compute_search_direction<F>(
1041 &self,
1042 objective: &F,
1043 params: &Array1<f64>,
1044 architecture: &OptimizationArchitecture,
1045 ) -> Array1<f64>
1046 where
1047 F: Fn(&ArrayView1<f64>) -> f64,
1048 {
1049 let mut direction = Array1::zeros(params.len());
1050
1051 let h = 1e-6;
1053 let f0 = objective(¶ms.view());
1054
1055 for i in 0..params.len() {
1056 let mut params_plus = params.clone();
1057 params_plus[i] += h;
1058 let f_plus = objective(¶ms_plus.view());
1059 direction[i] = (f_plus - f0) / h;
1060 }
1061
1062 for component in &architecture.optimizer_components {
1064 match component {
1065 OptimizerComponent::Momentum { decay } => {
1066 direction *= 1.0 - decay;
1068 }
1069 OptimizerComponent::Regularization {
1070 l1_weight,
1071 l2_weight,
1072 ..
1073 } => {
1074 for i in 0..direction.len() {
1076 direction[i] += l1_weight * params[i].signum() + l2_weight * params[i];
1077 }
1078 }
1079 _ => {}
1080 }
1081 }
1082
1083 direction
1084 }
1085}
1086
1087#[allow(dead_code)]
1089pub fn nas_optimize<F>(
1090 objective: F,
1091 initial_params: &ArrayView1<f64>,
1092 config: Option<LearnedOptimizationConfig>,
1093) -> super::OptimizeResult<OptimizeResults<f64>>
1094where
1095 F: Fn(&ArrayView1<f64>) -> f64,
1096{
1097 let config = config.unwrap_or_default();
1098 let mut nas_system = AdaptiveNASSystem::new(config);
1099 nas_system.optimize(objective, initial_params)
1100}
1101
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly constructed system starts at generation 0 with no population.
    #[test]
    fn test_nas_system_creation() {
        let config = LearnedOptimizationConfig::default();
        let nas_system = AdaptiveNASSystem::new(config);

        assert_eq!(nas_system.generation, 0);
        assert!(nas_system.architecture_population.is_empty());
    }

    // Random architectures always carry at least one layer, activation, and
    // optimizer component.
    #[test]
    fn test_architecture_generation() {
        let config = LearnedOptimizationConfig::default();
        let nas_system = AdaptiveNASSystem::new(config);

        let architecture = nas_system.generate_random_architecture().unwrap();

        assert!(!architecture.layers.is_empty());
        assert!(!architecture.activations.is_empty());
        assert!(!architecture.optimizer_components.is_empty());
    }

    // The fixed vocabulary contains the expected tokens and 19 ids total.
    #[test]
    fn test_vocabulary_creation() {
        let vocab = ArchitectureVocabulary::new();

        assert!(vocab.layer_types.contains_key("dense"));
        assert!(vocab.activations.contains_key("relu"));
        assert!(vocab.components.contains_key("momentum"));
        assert_eq!(vocab.vocab_size, 19);
    }

    // Mutation always assigns a fresh id to the offspring.
    #[test]
    fn test_architecture_mutation() {
        let config = LearnedOptimizationConfig::default();
        let nas_system = AdaptiveNASSystem::new(config);

        let base_arch = nas_system.generate_random_architecture().unwrap();
        let mutated = nas_system.mutate_architecture(&base_arch).unwrap();

        assert_ne!(base_arch.id, mutated.id);
    }

    // End-to-end smoke test of the public entry point on a 2D quadratic.
    #[test]
    fn test_nas_optimization() {
        let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let initial = Array1::from(vec![2.0, 2.0]);

        let config = LearnedOptimizationConfig {
            meta_training_episodes: 5,
            inner_steps: 10,
            ..Default::default()
        };

        let result = nas_optimize(objective, &initial.view(), Some(config)).unwrap();

        assert!(result.fun >= 0.0);
        assert_eq!(result.x.len(), 2);
        assert!(result.success);
    }
}
1166
/// Placeholder kept for module completeness; intentionally does nothing.
#[allow(dead_code)]
pub fn placeholder() {
}