scirs2_optimize/learned_optimizers/adaptive_nas_system.rs

//! Adaptive Neural Architecture Search (NAS) System for Optimization
//!
//! This module implements neural architecture search algorithms that adaptively
//! design optimization strategies based on problem characteristics. The system
//! can discover and evolve optimization architectures automatically.
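//!
//! A minimal usage sketch (assuming this module is reachable as
//! `scirs2_optimize::learned_optimizers` and re-exports these names):
//!
//! ```ignore
//! use scirs2_optimize::learned_optimizers::{AdaptiveNASSystem, LearnedOptimizationConfig};
//!
//! // Build a NAS system with default settings; no architecture is cached yet.
//! let mut nas = AdaptiveNASSystem::new(LearnedOptimizationConfig::default());
//! assert!(nas.get_cached_architecture("quadratic_2d").is_none());
//! ```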

use super::{
    ActivationType, LearnedOptimizationConfig, LearnedOptimizer, MetaOptimizerState,
    OptimizationProblem, TrainingTask,
};
use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1};
use scirs2_core::random::Rng;
use std::collections::HashMap;

/// Advanced Neural Architecture Search System for Optimization
#[derive(Debug, Clone)]
pub struct AdaptiveNASSystem {
    /// Configuration
    config: LearnedOptimizationConfig,
    /// Current architecture population
    architecture_population: Vec<OptimizationArchitecture>,
    /// Architecture performance history
    performance_history: HashMap<ArchitectureId, Vec<f64>>,
    /// Search controller network
    controller: ArchitectureController,
    /// Meta-optimizer state
    meta_state: MetaOptimizerState,
    /// Problem-specific architecture cache
    architecture_cache: HashMap<String, OptimizationArchitecture>,
    /// Search statistics
    search_stats: NASSearchStats,
    /// Current generation
    generation: usize,
}

/// Unique identifier for architectures
type ArchitectureId = String;

/// Architecture for optimization algorithms
#[derive(Debug, Clone)]
pub struct OptimizationArchitecture {
    /// Unique identifier
    pub id: ArchitectureId,
    /// Layer configuration
    pub layers: Vec<LayerConfig>,
    /// Connection pattern
    pub connections: Vec<Connection>,
    /// Activation functions
    pub activations: Vec<ActivationType>,
    /// Skip connections
    pub skip_connections: Vec<SkipConnection>,
    /// Optimization-specific components
    pub optimizer_components: Vec<OptimizerComponent>,
    /// Architecture complexity score
    pub complexity: f64,
    /// Performance metrics
    pub performance_metrics: ArchitectureMetrics,
}

/// Layer configuration in the architecture
#[derive(Debug, Clone)]
pub struct LayerConfig {
    /// Layer type
    pub layer_type: LayerType,
    /// Number of units/neurons
    pub units: usize,
    /// Dropout rate
    pub dropout: f64,
    /// Normalization type
    pub normalization: NormalizationType,
    /// Additional parameters
    pub parameters: HashMap<String, f64>,
}

/// Types of layers in optimization architectures
#[derive(Debug, Clone)]
pub enum LayerType {
    /// Dense/fully connected layer
    Dense,
    /// Convolutional layer (for structured problems)
    Convolution { kernel_size: usize, stride: usize },
    /// Attention layer
    Attention { num_heads: usize },
    /// LSTM layer
    LSTM { hidden_size: usize },
    /// GRU layer
    GRU { hidden_size: usize },
    /// Transformer block
    Transformer { num_heads: usize, ff_dim: usize },
    /// Graph neural network layer
    GraphNN { aggregation: String },
    /// Memory-augmented layer
    Memory { memory_size: usize },
}

/// Types of normalization
#[derive(Debug, Clone)]
pub enum NormalizationType {
    None,
    BatchNorm,
    LayerNorm,
    GroupNorm { groups: usize },
    InstanceNorm,
}

/// Connection between layers
#[derive(Debug, Clone)]
pub struct Connection {
    /// Source layer index
    pub from: usize,
    /// Target layer index
    pub to: usize,
    /// Connection weight
    pub weight: f64,
    /// Connection type
    pub connection_type: ConnectionType,
}

/// Types of connections
#[derive(Debug, Clone)]
pub enum ConnectionType {
    /// Standard feedforward
    Forward,
    /// Residual connection
    Residual,
    /// Dense connection (from all previous layers)
    Dense,
    /// Attention-based connection
    Attention,
}

/// Skip connection configuration
#[derive(Debug, Clone)]
pub struct SkipConnection {
    /// Source layer
    pub source: usize,
    /// Target layer
    pub target: usize,
    /// Skip type
    pub skip_type: SkipType,
}

/// Types of skip connections
#[derive(Debug, Clone)]
pub enum SkipType {
    /// Simple addition
    Add,
    /// Concatenation
    Concat,
    /// Gated connection
    Gated { gate_size: usize },
    /// Highway connection
    Highway,
}

/// Optimizer-specific components
#[derive(Debug, Clone)]
pub enum OptimizerComponent {
    /// Momentum component
    Momentum { decay: f64 },
    /// Adaptive learning rate
    AdaptiveLR {
        adaptation_rate: f64,
        min_lr: f64,
        max_lr: f64,
    },
    /// Second-order approximation
    SecondOrder {
        hessian_approximation: HessianApprox,
        regularization: f64,
    },
    /// Trust region component
    TrustRegion {
        initial_radius: f64,
        max_radius: f64,
        shrink_factor: f64,
        expand_factor: f64,
    },
    /// Line search component
    LineSearch {
        method: LineSearchMethod,
        max_nit: usize,
    },
    /// Regularization component
    Regularization {
        l1_weight: f64,
        l2_weight: f64,
        elastic_net_ratio: f64,
    },
}

/// Hessian approximation methods
#[derive(Debug, Clone)]
pub enum HessianApprox {
    BFGS,
    LBFGS { memory_size: usize },
    SR1,
    DFP,
    DiagonalApprox,
}

/// Line search methods
#[derive(Debug, Clone)]
pub enum LineSearchMethod {
    Backtracking,
    StrongWolfe,
    MoreThuente,
    Armijo,
    Exact,
}

/// Performance metrics for architectures
#[derive(Debug, Clone)]
pub struct ArchitectureMetrics {
    /// Average convergence rate
    pub convergence_rate: f64,
    /// Success rate on test problems
    pub success_rate: f64,
    /// Average function evaluations
    pub avg_evaluations: f64,
    /// Robustness score
    pub robustness: f64,
    /// Transfer learning capability
    pub transfer_score: f64,
    /// Computational efficiency
    pub efficiency: f64,
}

impl Default for ArchitectureMetrics {
    fn default() -> Self {
        Self {
            convergence_rate: 0.0,
            success_rate: 0.0,
            avg_evaluations: 0.0,
            robustness: 0.0,
            transfer_score: 0.0,
            efficiency: 0.0,
        }
    }
}

/// Controller for generating architectures
#[derive(Debug, Clone)]
pub struct ArchitectureController {
    /// LSTM-based controller network
    lstm_weights: Array3<f64>,
    /// Embedding layer for architecture components
    embedding_layer: Array2<f64>,
    /// Output layer for architecture decisions
    output_layer: Array2<f64>,
    /// Controller state
    controller_state: Array1<f64>,
    /// Vocabulary for architecture components
    vocabulary: ArchitectureVocabulary,
}

/// Vocabulary for architecture search
#[derive(Debug, Clone)]
pub struct ArchitectureVocabulary {
    /// Layer types mapping
    pub layer_types: HashMap<String, usize>,
    /// Activation functions mapping
    pub activations: HashMap<String, usize>,
    /// Optimizer components mapping
    pub components: HashMap<String, usize>,
    /// Total vocabulary size
    pub vocab_size: usize,
}

/// Search statistics
#[derive(Debug, Clone)]
pub struct NASSearchStats {
    /// Number of architectures evaluated
    pub architectures_evaluated: usize,
    /// Best performance found
    pub best_performance: f64,
    /// Search efficiency
    pub search_efficiency: f64,
    /// Diversity of population
    pub population_diversity: f64,
    /// Convergence indicators
    pub convergence_indicators: Vec<f64>,
}

impl Default for NASSearchStats {
    fn default() -> Self {
        Self {
            architectures_evaluated: 0,
            best_performance: f64::NEG_INFINITY,
            search_efficiency: 0.0,
            population_diversity: 1.0,
            convergence_indicators: Vec::new(),
        }
    }
}

impl AdaptiveNASSystem {
    /// Create new adaptive NAS system
    pub fn new(config: LearnedOptimizationConfig) -> Self {
        let vocabulary = ArchitectureVocabulary::new();
        let controller = ArchitectureController::new(&vocabulary, config.hidden_size);
        let hidden_size = config.hidden_size;

        Self {
            config,
            architecture_population: Vec::new(),
            performance_history: HashMap::new(),
            controller,
            meta_state: MetaOptimizerState {
                meta_params: Array1::zeros(100),
                network_weights: Array2::zeros((hidden_size, hidden_size)),
                performance_history: Vec::new(),
                adaptation_stats: super::AdaptationStatistics::default(),
                episode: 0,
            },
            architecture_cache: HashMap::new(),
            search_stats: NASSearchStats::default(),
            generation: 0,
        }
    }

    /// Search for optimal architectures for given problems
    pub fn search_architectures(
        &mut self,
        training_problems: &[OptimizationProblem],
    ) -> OptimizeResult<Vec<OptimizationArchitecture>> {
        // Initialize population if empty
        if self.architecture_population.is_empty() {
            self.initialize_population()?;
        }

        for generation in 0..self.config.meta_training_episodes {
            self.generation = generation;

            // Evaluate current population
            self.evaluate_population(training_problems)?;

            // Update controller based on performance
            self.update_controller()?;

            // Generate new architectures
            let new_architectures = self.generate_new_architectures()?;

            // Select best architectures for next generation
            self.select_next_generation(new_architectures)?;

            // Update search statistics
            self.update_search_stats()?;

            // Check convergence
            if self.check_convergence() {
                break;
            }
        }

        Ok(self.get_best_architectures())
    }

    /// Initialize population with diverse architectures
    fn initialize_population(&mut self) -> OptimizeResult<()> {
        for _ in 0..self.config.batch_size {
            let architecture = self.generate_random_architecture()?;
            self.architecture_population.push(architecture);
        }
        Ok(())
    }

    /// Generate a random architecture
    fn generate_random_architecture(&self) -> OptimizeResult<OptimizationArchitecture> {
        let num_layers = 2 + (scirs2_core::random::rng().random_range(0..8)); // 2-9 layers
        let mut layers = Vec::new();
        let mut connections = Vec::new();
        let mut activations = Vec::new();
        let mut optimizer_components = Vec::new();

        // Generate layers
        for i in 0..num_layers {
            let layer_type = self.sample_layer_type();
            let units = 16 + (scirs2_core::random::rng().random_range(0..256)); // 16-271 units

            layers.push(LayerConfig {
                layer_type,
                units,
                dropout: scirs2_core::random::rng().random_range(0.0..0.5),
                normalization: self.sample_normalization(),
                parameters: HashMap::new(),
            });

            activations.push(self.sample_activation());

            // Add a forward connection from the previous layer (the first layer has none)
            if i > 0 {
                connections.push(Connection {
                    from: i - 1,
                    to: i,
                    weight: 1.0,
                    connection_type: ConnectionType::Forward,
                });

                // With 30% probability, add a residual connection from a random earlier layer
                if i > 1 && scirs2_core::random::rng().random_range(0.0..1.0) < 0.3 {
                    let skip_source = scirs2_core::random::rng().random_range(0..i);
                    connections.push(Connection {
                        from: skip_source,
                        to: i,
                        weight: 0.5,
                        connection_type: ConnectionType::Residual,
                    });
                }
            }
        }

        // Generate optimizer components
        for _ in 0..(1 + scirs2_core::random::rng().random_range(0..4)) {
            optimizer_components.push(self.sample_optimizer_component());
        }

        let id = format!(
            "arch_{}",
            scirs2_core::random::rng().random_range(0..u64::MAX)
        );

        Ok(OptimizationArchitecture {
            id,
            layers,
            connections,
            activations,
            skip_connections: Vec::new(),
            optimizer_components,
            complexity: 0.0,
            performance_metrics: ArchitectureMetrics::default(),
        })
    }

    fn sample_layer_type(&self) -> LayerType {
        match scirs2_core::random::rng().random_range(0..8) {
            0 => LayerType::Dense,
            1 => LayerType::Attention {
                num_heads: 2 + scirs2_core::random::rng().random_range(0..6),
            },
            2 => LayerType::LSTM {
                hidden_size: 32 + scirs2_core::random::rng().random_range(0..128),
            },
            3 => LayerType::GRU {
                hidden_size: 32 + scirs2_core::random::rng().random_range(0..128),
            },
            4 => LayerType::Transformer {
                num_heads: 2 + scirs2_core::random::rng().random_range(0..6),
                ff_dim: 64 + scirs2_core::random::rng().random_range(0..256),
            },
            5 => LayerType::Memory {
                memory_size: 16 + scirs2_core::random::rng().random_range(0..64),
            },
            6 => LayerType::Convolution {
                kernel_size: 1 + scirs2_core::random::rng().random_range(0..5),
                stride: 1 + scirs2_core::random::rng().random_range(0..3),
            },
            _ => LayerType::GraphNN {
                aggregation: "mean".to_string(),
            },
        }
    }

    fn sample_normalization(&self) -> NormalizationType {
        match scirs2_core::random::rng().random_range(0..5) {
            0 => NormalizationType::None,
            1 => NormalizationType::BatchNorm,
            2 => NormalizationType::LayerNorm,
            3 => NormalizationType::GroupNorm {
                groups: 2 + scirs2_core::random::rng().random_range(0..6),
            },
            _ => NormalizationType::InstanceNorm,
        }
    }

    fn sample_activation(&self) -> ActivationType {
        match scirs2_core::random::rng().random_range(0..5) {
            0 => ActivationType::ReLU,
            1 => ActivationType::GELU,
            2 => ActivationType::Swish,
            3 => ActivationType::Tanh,
            _ => ActivationType::LeakyReLU,
        }
    }

    fn sample_optimizer_component(&self) -> OptimizerComponent {
        match scirs2_core::random::rng().random_range(0..6) {
            0 => OptimizerComponent::Momentum {
                decay: 0.8 + scirs2_core::random::rng().random_range(0.0..0.19),
            },
            1 => OptimizerComponent::AdaptiveLR {
                adaptation_rate: 0.001 + scirs2_core::random::rng().random_range(0.0..0.009),
                min_lr: 1e-8,
                max_lr: 1.0,
            },
            2 => OptimizerComponent::SecondOrder {
                hessian_approximation: HessianApprox::LBFGS {
                    memory_size: 5 + scirs2_core::random::rng().random_range(0..15),
                },
                regularization: 1e-6 + scirs2_core::random::rng().random_range(0.0..1e-3),
            },
            3 => OptimizerComponent::TrustRegion {
                initial_radius: 0.1 + scirs2_core::random::rng().random_range(0.0..0.9),
                max_radius: 10.0,
                shrink_factor: 0.25,
                expand_factor: 2.0,
            },
            4 => OptimizerComponent::LineSearch {
                method: LineSearchMethod::StrongWolfe,
                max_nit: 10 + scirs2_core::random::rng().random_range(0..20),
            },
            _ => OptimizerComponent::Regularization {
                l1_weight: scirs2_core::random::rng().random_range(0.0..0.01),
                l2_weight: scirs2_core::random::rng().random_range(0.0..0.01),
                elastic_net_ratio: scirs2_core::random::rng().random_range(0.0..1.0),
            },
        }
    }

    /// Evaluate population on training problems
    fn evaluate_population(
        &mut self,
        training_problems: &[OptimizationProblem],
    ) -> OptimizeResult<()> {
        // First evaluate all architectures
        let scores: Vec<_> = self
            .architecture_population
            .iter()
            .map(|architecture| {
                let mut total_score = 0.0;
                let mut num_evaluated = 0;

                // Limit to the first five problems for efficiency
                for problem in training_problems.iter().take(5) {
                    if let Ok(score) = self.evaluate_architecture_on_problem(architecture, problem)
                    {
                        total_score += score;
                        num_evaluated += 1;
                    }
                }

                if num_evaluated > 0 {
                    Some(total_score / num_evaluated as f64)
                } else {
                    None
                }
            })
            .collect();

        // Now update architectures with their scores
        for (architecture, score) in self.architecture_population.iter_mut().zip(scores.iter()) {
            if let Some(avg_score) = score {
                architecture.performance_metrics.convergence_rate = *avg_score;

                // Update performance history
                self.performance_history
                    .entry(architecture.id.clone())
                    .or_default()
                    .push(*avg_score);
            }
        }

        Ok(())
    }

    /// Evaluate a single architecture on a problem
    fn evaluate_architecture_on_problem(
        &self,
        architecture: &OptimizationArchitecture,
        _problem: &OptimizationProblem,
    ) -> OptimizeResult<f64> {
        // Simplified evaluation - in practice this would build and test the
        // actual architecture on the given problem
        let complexity_penalty = architecture.complexity * 0.01;
        let num_components = architecture.optimizer_components.len() as f64;
        let num_layers = architecture.layers.len() as f64;

        // Heuristic scoring based on architecture properties
        let base_score = (num_components * 0.1 + num_layers * 0.05).min(1.0);
        let final_score = base_score - complexity_penalty;

        Ok(final_score.max(0.0))
    }

    /// Update controller network
    fn update_controller(&mut self) -> OptimizeResult<()> {
        // Collect performance feedback
        let mut rewards = Vec::new();
        for arch in &self.architecture_population {
            rewards.push(arch.performance_metrics.convergence_rate);
        }

        if rewards.is_empty() {
            return Ok(());
        }

        // Update controller using a REINFORCE-like algorithm
        let baseline = rewards.iter().sum::<f64>() / rewards.len() as f64;

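        // REINFORCE-style credit assignment: an architecture whose reward exceeds
        // the mean baseline gets a positive advantage, strengthening the controller
        // weights that produced it; subtracting the baseline reduces gradient
        // variance without biasing the update.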
        for (i, &reward) in rewards.iter().enumerate() {
            let advantage = reward - baseline;

            // Update controller weights (simplified): scale one gate's weight
            // matrix by the advantage, cycling through gates by population index.
            // Note: indexing must use the first-axis length, not `len()`, which
            // for an Array3 returns the total element count.
            let num_gates = self.controller.lstm_weights.shape()[0];
            let (rows, cols) = {
                let shape = self.controller.lstm_weights.shape();
                (shape[1], shape[2])
            };
            let learning_rate = self.config.meta_learning_rate;
            let gate = i % num_gates;
            for j in 0..rows {
                for k in 0..cols {
                    self.controller.lstm_weights[(gate, j, k)] +=
                        learning_rate * advantage * 0.01;
                }
            }
        }

        Ok(())
    }

    /// Generate new architectures using controller
    fn generate_new_architectures(&mut self) -> OptimizeResult<Vec<OptimizationArchitecture>> {
        let mut new_architectures = Vec::new();

        for _ in 0..self.config.batch_size / 2 {
            // Generate architecture using controller
            let architecture = self.controller_generate_architecture()?;
            new_architectures.push(architecture);

            // Also add a mutated version of the best architecture
            if !self.architecture_population.is_empty() {
                let best_idx = self.get_best_architecture_index();
                let mutated = self.mutate_architecture(&self.architecture_population[best_idx])?;
                new_architectures.push(mutated);
            }
        }

        Ok(new_architectures)
    }

    /// Generate architecture using controller network
    fn controller_generate_architecture(&mut self) -> OptimizeResult<OptimizationArchitecture> {
        // Simplified architecture generation using the controller.
        // In practice, this would use the LSTM controller to generate architecture sequences.

        let mut architecture = self.generate_random_architecture()?;

        // Modify based on controller state
        let controller_influence = self.controller.controller_state.mean().unwrap_or(0.0);

        // Adjust architecture complexity based on controller
        if controller_influence > 0.5 {
            // Increase complexity
            if architecture.layers.len() < 10 {
                architecture.layers.push(LayerConfig {
                    layer_type: LayerType::Dense,
                    units: 64,
                    dropout: 0.1,
                    normalization: NormalizationType::LayerNorm,
                    parameters: HashMap::new(),
                });
            }
        } else {
            // Reduce complexity
            if architecture.layers.len() > 2 {
                architecture.layers.pop();
            }
        }

        Ok(architecture)
    }

    /// Mutate an existing architecture
    fn mutate_architecture(
        &self,
        base_arch: &OptimizationArchitecture,
    ) -> OptimizeResult<OptimizationArchitecture> {
        let mut mutated = base_arch.clone();
        mutated.id = format!(
            "mutated_{}",
            scirs2_core::random::rng().random_range(0..u64::MAX)
        );

        // Mutate with some probability
        if scirs2_core::random::rng().random_range(0.0..1.0) < 0.3 {
            // Mutate layer count
            if scirs2_core::random::rng().random_range(0.0..1.0) < 0.5 && mutated.layers.len() < 12
            {
                mutated.layers.push(LayerConfig {
                    layer_type: self.sample_layer_type(),
                    units: 32 + scirs2_core::random::rng().random_range(0..128),
                    dropout: scirs2_core::random::rng().random_range(0.0..0.5),
                    normalization: self.sample_normalization(),
                    parameters: HashMap::new(),
                });
            } else if mutated.layers.len() > 2 {
                mutated.layers.pop();
            }
        }

        // Mutate activations
        for activation in &mut mutated.activations {
            if scirs2_core::random::rng().random_range(0.0..1.0) < 0.2 {
                *activation = self.sample_activation();
            }
        }

        // Mutate optimizer components
        if scirs2_core::random::rng().random_range(0.0..1.0) < 0.4 {
            if scirs2_core::random::rng().random_range(0.0..1.0) < 0.5
                && mutated.optimizer_components.len() < 6
            {
                mutated
                    .optimizer_components
                    .push(self.sample_optimizer_component());
            } else if !mutated.optimizer_components.is_empty() {
                let idx =
                    scirs2_core::random::rng().random_range(0..mutated.optimizer_components.len());
                mutated.optimizer_components.remove(idx);
            }
        }

        Ok(mutated)
    }

    /// Select best architectures for next generation
    fn select_next_generation(
        &mut self,
        mut new_architectures: Vec<OptimizationArchitecture>,
    ) -> OptimizeResult<()> {
        // Combine current population with new architectures
        self.architecture_population.append(&mut new_architectures);

        // Sort by performance, best first
        self.architecture_population.sort_by(|a, b| {
            b.performance_metrics
                .convergence_rate
                .partial_cmp(&a.performance_metrics.convergence_rate)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Keep only the best architectures
        self.architecture_population
            .truncate(self.config.batch_size);

        Ok(())
    }

    /// Update search statistics
    fn update_search_stats(&mut self) -> OptimizeResult<()> {
        self.search_stats.architectures_evaluated += self.architecture_population.len();

        if let Some(best_arch) = self.architecture_population.first() {
            let best_performance = best_arch.performance_metrics.convergence_rate;
            if best_performance > self.search_stats.best_performance {
                self.search_stats.best_performance = best_performance;
            }
        }

        // Compute population diversity as the standard deviation of convergence rates
        let performances: Vec<f64> = self
            .architecture_population
            .iter()
            .map(|a| a.performance_metrics.convergence_rate)
            .collect();

        if performances.len() > 1 {
            let mean = performances.iter().sum::<f64>() / performances.len() as f64;
            let variance = performances
                .iter()
                .map(|&p| (p - mean).powi(2))
                .sum::<f64>()
                / performances.len() as f64;
            self.search_stats.population_diversity = variance.sqrt();
        }

        self.search_stats
            .convergence_indicators
            .push(self.search_stats.best_performance);

        Ok(())
    }

    /// Check if search has converged
    fn check_convergence(&self) -> bool {
        if self.search_stats.convergence_indicators.len() < 10 {
            return false;
        }

        // Check whether improvement has stagnated over the last 10 generations
        let indicators = &self.search_stats.convergence_indicators;
        let recent_improvements: Vec<f64> = indicators[indicators.len() - 10..]
            .windows(2)
            .map(|w| w[1] - w[0])
            .collect();

        let avg_improvement =
            recent_improvements.iter().sum::<f64>() / recent_improvements.len() as f64;
        avg_improvement < 1e-6
    }

    /// Get best architectures from current population
    fn get_best_architectures(&self) -> Vec<OptimizationArchitecture> {
        self.architecture_population.clone()
    }

    fn get_best_architecture_index(&self) -> usize {
        self.architecture_population
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                a.performance_metrics
                    .convergence_rate
                    .partial_cmp(&b.performance_metrics.convergence_rate)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Get search statistics
    pub fn get_search_stats(&self) -> &NASSearchStats {
        &self.search_stats
    }

    /// Cache architecture for specific problem type
    pub fn cache_architecture_for_problem(
        &mut self,
        problem_class: String,
        architecture: OptimizationArchitecture,
    ) {
        self.architecture_cache.insert(problem_class, architecture);
    }

    /// Retrieve cached architecture for problem type
    pub fn get_cached_architecture(
        &self,
        problem_class: &str,
    ) -> Option<&OptimizationArchitecture> {
        self.architecture_cache.get(problem_class)
    }
}

impl ArchitectureController {
    /// Create new architecture controller
    pub fn new(vocabulary: &ArchitectureVocabulary, hidden_size: usize) -> Self {
        Self {
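            // Small symmetric initialization in [-0.05, 0.05] keeps the
            // controller's initial outputs near zero.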
            lstm_weights: Array3::from_shape_fn((4, hidden_size, hidden_size), |_| {
                (scirs2_core::random::rng().random_range(0.0..1.0) - 0.5) * 0.1
            }),
            embedding_layer: Array2::from_shape_fn((hidden_size, vocabulary.vocab_size), |_| {
                (scirs2_core::random::rng().random_range(0.0..1.0) - 0.5) * 0.1
            }),
            output_layer: Array2::from_shape_fn((vocabulary.vocab_size, hidden_size), |_| {
                (scirs2_core::random::rng().random_range(0.0..1.0) - 0.5) * 0.1
            }),
            controller_state: Array1::zeros(hidden_size),
            vocabulary: vocabulary.clone(),
        }
    }
}

impl Default for ArchitectureVocabulary {
    fn default() -> Self {
        Self::new()
    }
}

impl ArchitectureVocabulary {
    /// Create new architecture vocabulary
    pub fn new() -> Self {
        let mut layer_types = HashMap::new();
        layer_types.insert("dense".to_string(), 0);
        layer_types.insert("conv".to_string(), 1);
        layer_types.insert("attention".to_string(), 2);
        layer_types.insert("lstm".to_string(), 3);
        layer_types.insert("gru".to_string(), 4);
        layer_types.insert("transformer".to_string(), 5);
        layer_types.insert("graph".to_string(), 6);
        layer_types.insert("memory".to_string(), 7);

        let mut activations = HashMap::new();
        activations.insert("relu".to_string(), 8);
        activations.insert("gelu".to_string(), 9);
        activations.insert("swish".to_string(), 10);
        activations.insert("tanh".to_string(), 11);
        activations.insert("leaky_relu".to_string(), 12);

        let mut components = HashMap::new();
        components.insert("momentum".to_string(), 13);
        components.insert("adaptive_lr".to_string(), 14);
        components.insert("second_order".to_string(), 15);
        components.insert("trust_region".to_string(), 16);
        components.insert("line_search".to_string(), 17);
        components.insert("regularization".to_string(), 18);

        Self {
            layer_types,
            activations,
            components,
            vocab_size: 19,
        }
    }
}

impl LearnedOptimizer for AdaptiveNASSystem {
    fn meta_train(&mut self, training_tasks: &[TrainingTask]) -> OptimizeResult<()> {
        let problems: Vec<OptimizationProblem> = training_tasks
            .iter()
            .map(|task| task.problem.clone())
            .collect();

        self.search_architectures(&problems)?;
        Ok(())
    }

    fn adapt_to_problem(
        &mut self,
        problem: &OptimizationProblem,
        _initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<()> {
        // If we already have a cached architecture for this problem type,
        // no adaptation is needed
        if self.get_cached_architecture(&problem.problem_class).is_some() {
            return Ok(());
        }

        // Generate a specialized architecture for this problem and cache it
        let specialized_arch = self.generate_random_architecture()?;
        self.cache_architecture_for_problem(problem.problem_class.clone(), specialized_arch);

        Ok(())
    }

    fn optimize<F>(
        &mut self,
        objective: F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        // Use best architecture to optimize
        if self.architecture_population.is_empty() {
            self.initialize_population()?;
        }

        let best_idx = self.get_best_architecture_index();
        let best_arch = &self.architecture_population[best_idx];

        // Simplified optimization using the best architecture
        let mut current_params = initial_params.to_owned();
        let mut best_params = current_params.clone();
        let mut best_value = objective(initial_params);
        let mut iterations = 0;

        for iter in 0..1000 {
            iterations = iter;

            // Apply optimization step based on architecture
            let step_size = self.compute_step_size(best_arch, iter);
            let direction = self.compute_search_direction(&objective, &current_params, best_arch);

            // Update parameters
            for i in 0..current_params.len() {
                current_params[i] -= step_size * direction[i];
            }

            let current_value = objective(&current_params.view());

            // Track the best point seen so far, so the returned x matches fun
            if current_value < best_value {
                best_value = current_value;
                best_params = current_params.clone();
            }

            // Check convergence
            if step_size < 1e-8 {
                break;
            }
        }

        Ok(OptimizeResults::<f64> {
            x: best_params,
            fun: best_value,
            success: true,
            nit: iterations,
            message: format!(
                "NAS optimization completed using architecture: {}",
                best_arch.id
            ),
            jac: None,
            hess: None,
            constr: None,
            nfev: iterations * best_arch.layers.len(), // Architecture depth affects evaluations
            njev: 0,
            nhev: 0,
            maxcv: 0,
            status: 0,
        })
    }

    fn get_state(&self) -> &MetaOptimizerState {
        &self.meta_state
    }

    fn reset(&mut self) {
        self.architecture_population.clear();
        self.performance_history.clear();
        self.search_stats = NASSearchStats::default();
        self.generation = 0;
    }
}

impl AdaptiveNASSystem {
    fn compute_step_size(&self, architecture: &OptimizationArchitecture, iteration: usize) -> f64 {
        let mut step_size = 0.01;

        // Adapt step size based on architecture components
        for component in &architecture.optimizer_components {
            match component {
                OptimizerComponent::AdaptiveLR {
                    adaptation_rate,
                    min_lr,
                    max_lr,
                } => {
                    step_size *= 1.0 + adaptation_rate * (iteration as f64).cos();
                    step_size = step_size.max(*min_lr).min(*max_lr);
                }
                OptimizerComponent::TrustRegion { initial_radius, .. } => {
                    step_size = step_size.min(*initial_radius);
                }
                _ => {}
            }
        }
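
        // Damp the step size with a 1/t-style decay so steps shrink as the
        // iteration count grows (a simple guard against oscillation).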
        step_size / (1.0 + iteration as f64 * 0.001)
    }

    fn compute_search_direction<F>(
        &self,
        objective: &F,
        params: &Array1<f64>,
        architecture: &OptimizationArchitecture,
    ) -> Array1<f64>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        let mut direction = Array1::zeros(params.len());

        // Estimate the gradient with forward differences:
        // direction[i] ~= (f(x + h*e_i) - f(x)) / h, costing n + 1 evaluations
        let h = 1e-6;
        let f0 = objective(&params.view());

        for i in 0..params.len() {
            let mut params_plus = params.clone();
            params_plus[i] += h;
            let f_plus = objective(&params_plus.view());
            direction[i] = (f_plus - f0) / h;
        }

        // Apply architecture-specific modifications
        for component in &architecture.optimizer_components {
            match component {
                OptimizerComponent::Momentum { decay } => {
                    // Simple momentum approximation
                    direction *= 1.0 - decay;
                }
                OptimizerComponent::Regularization {
                    l1_weight,
                    l2_weight,
                    ..
                } => {
                    // Add regularization
                    for i in 0..direction.len() {
                        direction[i] += l1_weight * params[i].signum() + l2_weight * params[i];
                    }
                }
                _ => {}
            }
        }

        direction
    }
}

/// Convenience function for NAS-based optimization
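///
/// A minimal usage sketch (marked `ignore` because a full NAS run is costly and
/// the exact `LearnedOptimizationConfig` defaults are crate-specific):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, ArrayView1};
///
/// // Minimize the sphere function from a simple starting point.
/// let sphere = |x: &ArrayView1<f64>| x.iter().map(|v| v * v).sum::<f64>();
/// let x0 = array![1.0, -2.0, 0.5];
/// let result = nas_optimize(sphere, &x0.view(), None).unwrap();
/// assert!(result.fun <= sphere(&x0.view()));
/// ```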
#[allow(dead_code)]
pub fn nas_optimize<F>(
    objective: F,
    initial_params: &ArrayView1<f64>,
    config: Option<LearnedOptimizationConfig>,
) -> super::OptimizeResult<OptimizeResults<f64>>
where
    F: Fn(&ArrayView1<f64>) -> f64,
{
    let config = config.unwrap_or_default();
    let mut nas_system = AdaptiveNASSystem::new(config);
    nas_system.optimize(objective, initial_params)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nas_system_creation() {
        let config = LearnedOptimizationConfig::default();
        let nas_system = AdaptiveNASSystem::new(config);

        assert_eq!(nas_system.generation, 0);
        assert!(nas_system.architecture_population.is_empty());
    }

    #[test]
    fn test_architecture_generation() {
        let config = LearnedOptimizationConfig::default();
        let nas_system = AdaptiveNASSystem::new(config);

        let architecture = nas_system.generate_random_architecture().unwrap();

        assert!(!architecture.layers.is_empty());
        assert!(!architecture.activations.is_empty());
        assert!(!architecture.optimizer_components.is_empty());
    }

    #[test]
    fn test_vocabulary_creation() {
        let vocab = ArchitectureVocabulary::new();

        assert!(vocab.layer_types.contains_key("dense"));
        assert!(vocab.activations.contains_key("relu"));
        assert!(vocab.components.contains_key("momentum"));
        assert_eq!(vocab.vocab_size, 19);
    }

    #[test]
    fn test_architecture_mutation() {
        let config = LearnedOptimizationConfig::default();
        let nas_system = AdaptiveNASSystem::new(config);

        let base_arch = nas_system.generate_random_architecture().unwrap();
        let mutated = nas_system.mutate_architecture(&base_arch).unwrap();

        assert_ne!(base_arch.id, mutated.id);
    }

    #[test]
    fn test_nas_optimization() {
        let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let initial = Array1::from(vec![2.0, 2.0]);

        let config = LearnedOptimizationConfig {
            meta_training_episodes: 5,
            inner_steps: 10,
            ..Default::default()
        };

        let result = nas_optimize(objective, &initial.view(), Some(config)).unwrap();

        assert!(result.fun >= 0.0);
        assert_eq!(result.x.len(), 2);
        assert!(result.success);
    }
}

#[allow(dead_code)]
pub fn placeholder() {
    // Placeholder function to prevent unused module warnings
}