sklears_model_selection/
neural_architecture_search.rs

1//! Neural Architecture Search (NAS) Integration
2//!
3//! This module provides automated neural network architecture optimization using various
4//! search strategies including evolutionary algorithms, reinforcement learning, and
5//! gradient-based methods for finding optimal neural network architectures.
6
7use scirs2_core::rand_prelude::IndexedRandom;
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::Rng;
10use scirs2_core::random::SeedableRng;
11use sklears_core::types::Float;
12use std::collections::HashMap;
13
14/// Neural Architecture Search strategies
15#[derive(Debug, Clone)]
16pub enum NASStrategy {
17    /// Evolutionary algorithm-based search
18    Evolutionary {
19        population_size: usize,
20
21        generations: usize,
22
23        mutation_rate: Float,
24
25        crossover_rate: Float,
26    },
27    /// Reinforcement learning-based search
28    ReinforcementLearning {
29        episodes: usize,
30        learning_rate: Float,
31        exploration_rate: Float,
32    },
33    /// Gradient-based differentiable architecture search
34    GDAS {
35        search_epochs: usize,
36        learning_rate: Float,
37        weight_decay: Float,
38    },
39    /// Random search baseline
40    RandomSearch { n_trials: usize, max_depth: usize },
41    /// Progressive search with increasing complexity
42    Progressive {
43        stages: usize,
44        complexity_growth: Float,
45    },
46    /// Bayesian optimization for architecture search
47    BayesianOptimization {
48        n_trials: usize,
49        acquisition_function: String,
50    },
51}
52
53/// Neural network architecture representation
54#[derive(Debug, Clone)]
55pub struct NeuralArchitecture {
56    /// Number of layers
57    pub num_layers: usize,
58    /// Hidden layer sizes
59    pub layer_sizes: Vec<usize>,
60    /// Activation functions for each layer
61    pub activations: Vec<String>,
62    /// Dropout rates for each layer
63    pub dropout_rates: Vec<Float>,
64    /// Batch normalization flags
65    pub batch_norm: Vec<bool>,
66    /// Skip connections (ResNet-style)
67    pub skip_connections: Vec<(usize, usize)>,
68    /// Architecture complexity score
69    pub complexity_score: Float,
70}
71
72/// Architecture search space definition
73#[derive(Debug, Clone)]
74pub struct ArchitectureSearchSpace {
75    /// Range of layer counts
76    pub layer_count_range: (usize, usize),
77    /// Range of neurons per layer
78    pub neuron_count_range: (usize, usize),
79    /// Available activation functions
80    pub activation_options: Vec<String>,
81    /// Dropout rate range
82    pub dropout_range: (Float, Float),
83    /// Maximum skip connection distance
84    pub max_skip_distance: usize,
85    /// Whether to use batch normalization
86    pub use_batch_norm: bool,
87}
88
89/// NAS configuration
90#[derive(Debug, Clone)]
91pub struct NASConfig {
92    pub strategy: NASStrategy,
93    pub search_space: ArchitectureSearchSpace,
94    pub evaluation_metric: String,
95    pub max_evaluation_time: Option<u64>,
96    pub early_stopping_patience: usize,
97    pub validation_split: Float,
98    pub random_state: Option<u64>,
99    pub parallel_evaluations: usize,
100}
101
102/// Architecture evaluation result
103#[derive(Debug, Clone)]
104pub struct ArchitectureEvaluation {
105    pub architecture: NeuralArchitecture,
106    pub validation_score: Float,
107    pub training_time: Float,
108    pub parameters_count: usize,
109    pub flops: usize,
110    pub memory_usage: Float,
111}
112
113/// NAS optimization result
114#[derive(Debug, Clone)]
115pub struct NASResult {
116    pub best_architecture: NeuralArchitecture,
117    pub best_score: Float,
118    pub search_history: Vec<ArchitectureEvaluation>,
119    pub total_search_time: Float,
120    pub architectures_evaluated: usize,
121    pub convergence_curve: Vec<Float>,
122}
123
124/// Neural Architecture Search optimizer
125#[derive(Debug, Clone)]
126pub struct NASOptimizer {
127    config: NASConfig,
128    rng: StdRng,
129}
130
131impl Default for ArchitectureSearchSpace {
132    fn default() -> Self {
133        Self {
134            layer_count_range: (1, 10),
135            neuron_count_range: (16, 1024),
136            activation_options: vec![
137                "relu".to_string(),
138                "tanh".to_string(),
139                "sigmoid".to_string(),
140                "swish".to_string(),
141                "gelu".to_string(),
142            ],
143            dropout_range: (0.0, 0.5),
144            max_skip_distance: 3,
145            use_batch_norm: true,
146        }
147    }
148}
149
150impl Default for NASConfig {
151    fn default() -> Self {
152        Self {
153            strategy: NASStrategy::Evolutionary {
154                population_size: 20,
155                generations: 50,
156                mutation_rate: 0.1,
157                crossover_rate: 0.7,
158            },
159            search_space: ArchitectureSearchSpace::default(),
160            evaluation_metric: "accuracy".to_string(),
161            max_evaluation_time: Some(3600), // 1 hour
162            early_stopping_patience: 10,
163            validation_split: 0.2,
164            random_state: None,
165            parallel_evaluations: 4,
166        }
167    }
168}
169
170impl NASOptimizer {
171    /// Create a new NAS optimizer
172    pub fn new(config: NASConfig) -> Self {
173        let rng = match config.random_state {
174            Some(seed) => StdRng::seed_from_u64(seed),
175            None => {
176                use scirs2_core::random::thread_rng;
177                StdRng::from_rng(&mut thread_rng())
178            }
179        };
180
181        Self { config, rng }
182    }
183
184    /// Search for optimal neural architecture
185    pub fn search<F>(&mut self, evaluation_fn: F) -> Result<NASResult, Box<dyn std::error::Error>>
186    where
187        F: Fn(&NeuralArchitecture) -> Result<ArchitectureEvaluation, Box<dyn std::error::Error>>,
188    {
189        let start_time = std::time::Instant::now();
190
191        let result = match &self.config.strategy {
192            NASStrategy::Evolutionary { .. } => self.evolutionary_search(&evaluation_fn)?,
193            NASStrategy::ReinforcementLearning { .. } => {
194                self.reinforcement_learning_search(&evaluation_fn)?
195            }
196            NASStrategy::GDAS { .. } => self.gradient_based_search(&evaluation_fn)?,
197            NASStrategy::RandomSearch { .. } => self.random_search(&evaluation_fn)?,
198            NASStrategy::Progressive { .. } => self.progressive_search(&evaluation_fn)?,
199            NASStrategy::BayesianOptimization { .. } => {
200                self.bayesian_optimization_search(&evaluation_fn)?
201            }
202        };
203
204        let total_time = start_time.elapsed().as_secs_f64() as Float;
205
206        Ok(NASResult {
207            best_architecture: result.0,
208            best_score: result.1,
209            search_history: result.2,
210            total_search_time: total_time,
211            architectures_evaluated: result.3,
212            convergence_curve: result.4,
213        })
214    }
215
216    /// Evolutionary algorithm-based architecture search
217    fn evolutionary_search<F>(
218        &mut self,
219        evaluation_fn: &F,
220    ) -> Result<
221        (
222            NeuralArchitecture,
223            Float,
224            Vec<ArchitectureEvaluation>,
225            usize,
226            Vec<Float>,
227        ),
228        Box<dyn std::error::Error>,
229    >
230    where
231        F: Fn(&NeuralArchitecture) -> Result<ArchitectureEvaluation, Box<dyn std::error::Error>>,
232    {
233        let (population_size, generations, mutation_rate, crossover_rate) =
234            match &self.config.strategy {
235                NASStrategy::Evolutionary {
236                    population_size,
237                    generations,
238                    mutation_rate,
239                    crossover_rate,
240                } => (
241                    *population_size,
242                    *generations,
243                    *mutation_rate,
244                    *crossover_rate,
245                ),
246                _ => unreachable!(),
247            };
248
249        let mut population = self.initialize_population(population_size)?;
250        let mut search_history = Vec::new();
251        let mut convergence_curve = Vec::new();
252        let mut best_architecture = population[0].clone();
253        let mut best_score = Float::NEG_INFINITY;
254        let mut evaluations_count = 0;
255
256        for _generation in 0..generations {
257            // Evaluate population
258            let mut evaluations = Vec::new();
259            for architecture in &population {
260                let evaluation = evaluation_fn(architecture)?;
261                evaluations.push(evaluation.clone());
262                search_history.push(evaluation.clone());
263                evaluations_count += 1;
264
265                if evaluation.validation_score > best_score {
266                    best_score = evaluation.validation_score;
267                    best_architecture = architecture.clone();
268                }
269            }
270
271            // Selection, crossover, and mutation
272            population =
273                self.evolve_population(&population, &evaluations, crossover_rate, mutation_rate)?;
274
275            // Track convergence
276            let generation_best = evaluations
277                .iter()
278                .map(|e| e.validation_score)
279                .fold(Float::NEG_INFINITY, |a, b| a.max(b));
280            convergence_curve.push(generation_best);
281
282            // Early stopping check
283            if self.check_early_stopping(&convergence_curve) {
284                break;
285            }
286        }
287
288        Ok((
289            best_architecture,
290            best_score,
291            search_history,
292            evaluations_count,
293            convergence_curve,
294        ))
295    }
296
297    /// Reinforcement learning-based architecture search
298    fn reinforcement_learning_search<F>(
299        &mut self,
300        evaluation_fn: &F,
301    ) -> Result<
302        (
303            NeuralArchitecture,
304            Float,
305            Vec<ArchitectureEvaluation>,
306            usize,
307            Vec<Float>,
308        ),
309        Box<dyn std::error::Error>,
310    >
311    where
312        F: Fn(&NeuralArchitecture) -> Result<ArchitectureEvaluation, Box<dyn std::error::Error>>,
313    {
314        let (episodes, learning_rate, exploration_rate) = match &self.config.strategy {
315            NASStrategy::ReinforcementLearning {
316                episodes,
317                learning_rate,
318                exploration_rate,
319            } => (*episodes, *learning_rate, *exploration_rate),
320            _ => unreachable!(),
321        };
322
323        let mut search_history = Vec::new();
324        let mut convergence_curve = Vec::new();
325        let mut best_architecture = self.generate_random_architecture()?;
326        let mut best_score = Float::NEG_INFINITY;
327        let mut evaluations_count = 0;
328
329        // Simple policy-based RL approach
330        let mut policy_weights = HashMap::new();
331        let mut epsilon = exploration_rate;
332
333        for _episode in 0..episodes {
334            // Generate architecture using current policy
335            let architecture = if self.rng.random::<Float>() < epsilon {
336                self.generate_random_architecture()?
337            } else {
338                self.generate_architecture_from_policy(&policy_weights)?
339            };
340
341            // Evaluate architecture
342            let evaluation = evaluation_fn(&architecture)?;
343            search_history.push(evaluation.clone());
344            evaluations_count += 1;
345
346            // Update policy weights based on reward
347            let reward = evaluation.validation_score;
348            self.update_policy_weights(&mut policy_weights, &architecture, reward, learning_rate);
349
350            if evaluation.validation_score > best_score {
351                best_score = evaluation.validation_score;
352                best_architecture = architecture.clone();
353            }
354
355            convergence_curve.push(best_score);
356
357            // Decay exploration rate
358            epsilon *= 0.99;
359
360            // Early stopping check
361            if self.check_early_stopping(&convergence_curve) {
362                break;
363            }
364        }
365
366        Ok((
367            best_architecture,
368            best_score,
369            search_history,
370            evaluations_count,
371            convergence_curve,
372        ))
373    }
374
375    /// Gradient-based differentiable architecture search
376    fn gradient_based_search<F>(
377        &mut self,
378        evaluation_fn: &F,
379    ) -> Result<
380        (
381            NeuralArchitecture,
382            Float,
383            Vec<ArchitectureEvaluation>,
384            usize,
385            Vec<Float>,
386        ),
387        Box<dyn std::error::Error>,
388    >
389    where
390        F: Fn(&NeuralArchitecture) -> Result<ArchitectureEvaluation, Box<dyn std::error::Error>>,
391    {
392        let (search_epochs, learning_rate, _weight_decay) = match &self.config.strategy {
393            NASStrategy::GDAS {
394                search_epochs,
395                learning_rate,
396                weight_decay,
397            } => (*search_epochs, *learning_rate, *weight_decay),
398            _ => unreachable!(),
399        };
400
401        let mut search_history = Vec::new();
402        let mut convergence_curve = Vec::new();
403        let mut best_architecture = self.generate_random_architecture()?;
404        let mut best_score = Float::NEG_INFINITY;
405        let mut evaluations_count = 0;
406
407        // Simulate gradient-based search with architecture parameters
408        let mut architecture_params = self.initialize_architecture_parameters()?;
409
410        for _epoch in 0..search_epochs {
411            // Sample architecture from current parameters
412            let architecture = self.sample_architecture_from_params(&architecture_params)?;
413
414            // Evaluate architecture
415            let evaluation = evaluation_fn(&architecture)?;
416            search_history.push(evaluation.clone());
417            evaluations_count += 1;
418
419            // Update architecture parameters (simplified gradient update)
420            self.update_architecture_parameters(
421                &mut architecture_params,
422                &evaluation,
423                learning_rate,
424            );
425
426            if evaluation.validation_score > best_score {
427                best_score = evaluation.validation_score;
428                best_architecture = architecture.clone();
429            }
430
431            convergence_curve.push(best_score);
432
433            // Early stopping check
434            if self.check_early_stopping(&convergence_curve) {
435                break;
436            }
437        }
438
439        Ok((
440            best_architecture,
441            best_score,
442            search_history,
443            evaluations_count,
444            convergence_curve,
445        ))
446    }
447
448    /// Random search baseline
449    fn random_search<F>(
450        &mut self,
451        evaluation_fn: &F,
452    ) -> Result<
453        (
454            NeuralArchitecture,
455            Float,
456            Vec<ArchitectureEvaluation>,
457            usize,
458            Vec<Float>,
459        ),
460        Box<dyn std::error::Error>,
461    >
462    where
463        F: Fn(&NeuralArchitecture) -> Result<ArchitectureEvaluation, Box<dyn std::error::Error>>,
464    {
465        let (n_trials, _max_depth) = match &self.config.strategy {
466            NASStrategy::RandomSearch {
467                n_trials,
468                max_depth,
469            } => (*n_trials, *max_depth),
470            _ => unreachable!(),
471        };
472
473        let mut search_history = Vec::new();
474        let mut convergence_curve = Vec::new();
475        let mut best_architecture = self.generate_random_architecture()?;
476        let mut best_score = Float::NEG_INFINITY;
477
478        for _trial in 0..n_trials {
479            let architecture = self.generate_random_architecture()?;
480            let evaluation = evaluation_fn(&architecture)?;
481            search_history.push(evaluation.clone());
482
483            if evaluation.validation_score > best_score {
484                best_score = evaluation.validation_score;
485                best_architecture = architecture.clone();
486            }
487
488            convergence_curve.push(best_score);
489
490            // Early stopping check
491            if self.check_early_stopping(&convergence_curve) {
492                break;
493            }
494        }
495
496        Ok((
497            best_architecture,
498            best_score,
499            search_history,
500            n_trials,
501            convergence_curve,
502        ))
503    }
504
505    /// Progressive search with increasing complexity
506    fn progressive_search<F>(
507        &mut self,
508        evaluation_fn: &F,
509    ) -> Result<
510        (
511            NeuralArchitecture,
512            Float,
513            Vec<ArchitectureEvaluation>,
514            usize,
515            Vec<Float>,
516        ),
517        Box<dyn std::error::Error>,
518    >
519    where
520        F: Fn(&NeuralArchitecture) -> Result<ArchitectureEvaluation, Box<dyn std::error::Error>>,
521    {
522        let (stages, complexity_growth) = match &self.config.strategy {
523            NASStrategy::Progressive {
524                stages,
525                complexity_growth,
526            } => (*stages, *complexity_growth),
527            _ => unreachable!(),
528        };
529
530        let mut search_history = Vec::new();
531        let mut convergence_curve = Vec::new();
532        let mut best_architecture = self.generate_random_architecture()?;
533        let mut best_score = Float::NEG_INFINITY;
534        let mut evaluations_count = 0;
535
536        let mut current_complexity = 1.0;
537        let trials_per_stage = 20;
538
539        for _stage in 0..stages {
540            // Search with current complexity constraint
541            for _ in 0..trials_per_stage {
542                let architecture =
543                    self.generate_architecture_with_complexity(current_complexity)?;
544                let evaluation = evaluation_fn(&architecture)?;
545                search_history.push(evaluation.clone());
546                evaluations_count += 1;
547
548                if evaluation.validation_score > best_score {
549                    best_score = evaluation.validation_score;
550                    best_architecture = architecture.clone();
551                }
552
553                convergence_curve.push(best_score);
554            }
555
556            // Increase complexity for next stage
557            current_complexity *= complexity_growth;
558
559            // Early stopping check
560            if self.check_early_stopping(&convergence_curve) {
561                break;
562            }
563        }
564
565        Ok((
566            best_architecture,
567            best_score,
568            search_history,
569            evaluations_count,
570            convergence_curve,
571        ))
572    }
573
574    /// Bayesian optimization for architecture search
575    fn bayesian_optimization_search<F>(
576        &mut self,
577        evaluation_fn: &F,
578    ) -> Result<
579        (
580            NeuralArchitecture,
581            Float,
582            Vec<ArchitectureEvaluation>,
583            usize,
584            Vec<Float>,
585        ),
586        Box<dyn std::error::Error>,
587    >
588    where
589        F: Fn(&NeuralArchitecture) -> Result<ArchitectureEvaluation, Box<dyn std::error::Error>>,
590    {
591        let (n_trials, _acquisition_function) = match &self.config.strategy {
592            NASStrategy::BayesianOptimization {
593                n_trials,
594                acquisition_function,
595            } => (*n_trials, acquisition_function),
596            _ => unreachable!(),
597        };
598
599        let mut search_history = Vec::new();
600        let mut convergence_curve = Vec::new();
601        let mut best_architecture = self.generate_random_architecture()?;
602        let mut best_score = Float::NEG_INFINITY;
603
604        // Initialize with random samples
605        let init_samples = 5;
606        let mut evaluated_architectures = Vec::new();
607
608        for _ in 0..init_samples {
609            let architecture = self.generate_random_architecture()?;
610            let evaluation = evaluation_fn(&architecture)?;
611            search_history.push(evaluation.clone());
612            evaluated_architectures.push((architecture.clone(), evaluation.validation_score));
613
614            if evaluation.validation_score > best_score {
615                best_score = evaluation.validation_score;
616                best_architecture = architecture.clone();
617            }
618
619            convergence_curve.push(best_score);
620        }
621
622        // Bayesian optimization loop
623        for _ in init_samples..n_trials {
624            // Select next architecture using acquisition function
625            let architecture = self.select_next_architecture_bayesian(&evaluated_architectures)?;
626            let evaluation = evaluation_fn(&architecture)?;
627            search_history.push(evaluation.clone());
628            evaluated_architectures.push((architecture.clone(), evaluation.validation_score));
629
630            if evaluation.validation_score > best_score {
631                best_score = evaluation.validation_score;
632                best_architecture = architecture.clone();
633            }
634
635            convergence_curve.push(best_score);
636
637            // Early stopping check
638            if self.check_early_stopping(&convergence_curve) {
639                break;
640            }
641        }
642
643        Ok((
644            best_architecture,
645            best_score,
646            search_history,
647            n_trials,
648            convergence_curve,
649        ))
650    }
651
652    /// Generate random neural architecture
653    fn generate_random_architecture(
654        &mut self,
655    ) -> Result<NeuralArchitecture, Box<dyn std::error::Error>> {
656        let num_layers = self.rng.gen_range(
657            self.config.search_space.layer_count_range.0
658                ..=self.config.search_space.layer_count_range.1,
659        );
660
661        let mut layer_sizes = Vec::new();
662        let mut activations = Vec::new();
663        let mut dropout_rates = Vec::new();
664        let mut batch_norm = Vec::new();
665
666        for _ in 0..num_layers {
667            let size = self.rng.gen_range(
668                self.config.search_space.neuron_count_range.0
669                    ..=self.config.search_space.neuron_count_range.1,
670            );
671            layer_sizes.push(size);
672
673            let activation = self
674                .config
675                .search_space
676                .activation_options
677                .choose(&mut self.rng)
678                .unwrap()
679                .clone();
680            activations.push(activation);
681
682            let dropout = self.rng.gen_range(
683                self.config.search_space.dropout_range.0..=self.config.search_space.dropout_range.1,
684            );
685            dropout_rates.push(dropout);
686
687            batch_norm.push(self.config.search_space.use_batch_norm && self.rng.gen_bool(0.5));
688        }
689
690        // Generate skip connections
691        let mut skip_connections = Vec::new();
692        if num_layers > 2 {
693            let n_skip = self.rng.gen_range(0..num_layers / 2 + 1);
694            for _ in 0..n_skip {
695                let from = self.rng.gen_range(0..num_layers - 1);
696                let max_to =
697                    (from + self.config.search_space.max_skip_distance).min(num_layers - 1);
698                if max_to > from {
699                    let to = self.rng.gen_range(from + 1..max_to + 1);
700                    skip_connections.push((from, to));
701                }
702            }
703        }
704
705        let complexity_score = self.calculate_complexity_score(&layer_sizes, &skip_connections);
706
707        Ok(NeuralArchitecture {
708            num_layers,
709            layer_sizes,
710            activations,
711            dropout_rates,
712            batch_norm,
713            skip_connections,
714            complexity_score,
715        })
716    }
717
718    /// Initialize population for evolutionary algorithm
719    fn initialize_population(
720        &mut self,
721        size: usize,
722    ) -> Result<Vec<NeuralArchitecture>, Box<dyn std::error::Error>> {
723        let mut population = Vec::new();
724        for _ in 0..size {
725            population.push(self.generate_random_architecture()?);
726        }
727        Ok(population)
728    }
729
730    /// Evolve population using genetic operators
731    fn evolve_population(
732        &mut self,
733        population: &[NeuralArchitecture],
734        evaluations: &[ArchitectureEvaluation],
735        crossover_rate: Float,
736        mutation_rate: Float,
737    ) -> Result<Vec<NeuralArchitecture>, Box<dyn std::error::Error>> {
738        let mut new_population = Vec::new();
739        let population_size = population.len();
740
741        // Keep best individuals (elitism)
742        let mut sorted_indices: Vec<usize> = (0..population_size).collect();
743        sorted_indices.sort_by(|&a, &b| {
744            evaluations[b]
745                .validation_score
746                .partial_cmp(&evaluations[a].validation_score)
747                .unwrap()
748        });
749
750        let elite_count = population_size / 4;
751        for &idx in sorted_indices.iter().take(elite_count) {
752            new_population.push(population[idx].clone());
753        }
754
755        // Generate offspring through crossover and mutation
756        while new_population.len() < population_size {
757            // Extract all random values first to avoid multiple mutable borrows
758            let crossover_prob = self.rng.random::<Float>();
759            let mutation_prob = self.rng.random::<Float>();
760
761            let (parent1, parent2) = self.tournament_selection_pair(population, evaluations, 3)?;
762
763            let mut offspring = if crossover_prob < crossover_rate {
764                self.crossover(parent1, parent2)?
765            } else {
766                parent1.clone()
767            };
768
769            if mutation_prob < mutation_rate {
770                offspring = self.mutate(&offspring)?;
771            }
772
773            new_population.push(offspring);
774        }
775
776        Ok(new_population)
777    }
778
779    /// Tournament selection for evolutionary algorithm
780    fn tournament_selection<'a>(
781        &mut self,
782        population: &'a [NeuralArchitecture],
783        evaluations: &[ArchitectureEvaluation],
784        tournament_size: usize,
785    ) -> Result<&'a NeuralArchitecture, Box<dyn std::error::Error>> {
786        let mut best_idx = 0;
787        let mut best_score = Float::NEG_INFINITY;
788
789        for _ in 0..tournament_size {
790            let idx = self.rng.gen_range(0..population.len());
791            if evaluations[idx].validation_score > best_score {
792                best_score = evaluations[idx].validation_score;
793                best_idx = idx;
794            }
795        }
796
797        Ok(&population[best_idx])
798    }
799
800    /// Tournament selection for two parents
801    fn tournament_selection_pair<'a>(
802        &mut self,
803        population: &'a [NeuralArchitecture],
804        evaluations: &[ArchitectureEvaluation],
805        tournament_size: usize,
806    ) -> Result<(&'a NeuralArchitecture, &'a NeuralArchitecture), Box<dyn std::error::Error>> {
807        let parent1_idx =
808            self.tournament_selection_idx(population, evaluations, tournament_size)?;
809        let parent2_idx =
810            self.tournament_selection_idx(population, evaluations, tournament_size)?;
811
812        Ok((&population[parent1_idx], &population[parent2_idx]))
813    }
814
815    /// Tournament selection returning index
816    fn tournament_selection_idx(
817        &mut self,
818        population: &[NeuralArchitecture],
819        evaluations: &[ArchitectureEvaluation],
820        tournament_size: usize,
821    ) -> Result<usize, Box<dyn std::error::Error>> {
822        let mut best_idx = 0;
823        let mut best_score = Float::NEG_INFINITY;
824
825        for _ in 0..tournament_size {
826            let idx = self.rng.gen_range(0..population.len());
827            if evaluations[idx].validation_score > best_score {
828                best_score = evaluations[idx].validation_score;
829                best_idx = idx;
830            }
831        }
832
833        Ok(best_idx)
834    }
835
836    /// Crossover two architectures
837    fn crossover(
838        &mut self,
839        parent1: &NeuralArchitecture,
840        parent2: &NeuralArchitecture,
841    ) -> Result<NeuralArchitecture, Box<dyn std::error::Error>> {
842        let num_layers = if self.rng.gen_bool(0.5) {
843            parent1.num_layers
844        } else {
845            parent2.num_layers
846        };
847
848        let mut layer_sizes = Vec::new();
849        let mut activations = Vec::new();
850        let mut dropout_rates = Vec::new();
851        let mut batch_norm = Vec::new();
852
853        for i in 0..num_layers {
854            let parent1_idx = i.min(parent1.num_layers - 1);
855            let parent2_idx = i.min(parent2.num_layers - 1);
856
857            layer_sizes.push(if self.rng.gen_bool(0.5) {
858                parent1.layer_sizes[parent1_idx]
859            } else {
860                parent2.layer_sizes[parent2_idx]
861            });
862
863            activations.push(if self.rng.gen_bool(0.5) {
864                parent1.activations[parent1_idx].clone()
865            } else {
866                parent2.activations[parent2_idx].clone()
867            });
868
869            dropout_rates.push(if self.rng.gen_bool(0.5) {
870                parent1.dropout_rates[parent1_idx]
871            } else {
872                parent2.dropout_rates[parent2_idx]
873            });
874
875            batch_norm.push(if self.rng.gen_bool(0.5) {
876                parent1.batch_norm[parent1_idx]
877            } else {
878                parent2.batch_norm[parent2_idx]
879            });
880        }
881
882        // Combine skip connections
883        let mut skip_connections = Vec::new();
884        for &(from, to) in &parent1.skip_connections {
885            if from < num_layers && to < num_layers {
886                skip_connections.push((from, to));
887            }
888        }
889        for &(from, to) in &parent2.skip_connections {
890            if from < num_layers && to < num_layers && !skip_connections.contains(&(from, to)) {
891                skip_connections.push((from, to));
892            }
893        }
894
895        let complexity_score = self.calculate_complexity_score(&layer_sizes, &skip_connections);
896
897        Ok(NeuralArchitecture {
898            num_layers,
899            layer_sizes,
900            activations,
901            dropout_rates,
902            batch_norm,
903            skip_connections,
904            complexity_score,
905        })
906    }
907
908    /// Mutate an architecture
909    fn mutate(
910        &mut self,
911        architecture: &NeuralArchitecture,
912    ) -> Result<NeuralArchitecture, Box<dyn std::error::Error>> {
913        let mut mutated = architecture.clone();
914
915        // Mutate number of layers
916        if self.rng.gen_bool(0.1) {
917            let change = if self.rng.gen_bool(0.5) { 1 } else { -1 };
918            mutated.num_layers = ((mutated.num_layers as i32 + change) as usize)
919                .max(self.config.search_space.layer_count_range.0)
920                .min(self.config.search_space.layer_count_range.1);
921        }
922
923        // Adjust vectors to match new layer count
924        while mutated.layer_sizes.len() < mutated.num_layers {
925            mutated.layer_sizes.push(self.rng.gen_range(
926                self.config.search_space.neuron_count_range.0
927                    ..=self.config.search_space.neuron_count_range.1,
928            ));
929            mutated.activations.push(
930                self.config
931                    .search_space
932                    .activation_options
933                    .choose(&mut self.rng)
934                    .unwrap()
935                    .clone(),
936            );
937            mutated.dropout_rates.push(self.rng.gen_range(
938                self.config.search_space.dropout_range.0..=self.config.search_space.dropout_range.1,
939            ));
940            mutated
941                .batch_norm
942                .push(self.config.search_space.use_batch_norm && self.rng.gen_bool(0.5));
943        }
944        mutated.layer_sizes.truncate(mutated.num_layers);
945        mutated.activations.truncate(mutated.num_layers);
946        mutated.dropout_rates.truncate(mutated.num_layers);
947        mutated.batch_norm.truncate(mutated.num_layers);
948
949        // Mutate layer sizes
950        for size in &mut mutated.layer_sizes {
951            if self.rng.gen_bool(0.2) {
952                let change = self.rng.gen_range(-50..50 + 1);
953                *size = ((*size as i32 + change) as usize)
954                    .max(self.config.search_space.neuron_count_range.0)
955                    .min(self.config.search_space.neuron_count_range.1);
956            }
957        }
958
959        // Mutate activations
960        for activation in &mut mutated.activations {
961            if self.rng.gen_bool(0.1) {
962                *activation = self
963                    .config
964                    .search_space
965                    .activation_options
966                    .choose(&mut self.rng)
967                    .unwrap()
968                    .clone();
969            }
970        }
971
972        // Mutate dropout rates
973        for dropout in &mut mutated.dropout_rates {
974            if self.rng.gen_bool(0.2) {
975                let change = self.rng.gen_range(-0.1..1.1);
976                *dropout = (*dropout + change)
977                    .max(self.config.search_space.dropout_range.0)
978                    .min(self.config.search_space.dropout_range.1);
979            }
980        }
981
982        // Mutate batch normalization
983        for bn in &mut mutated.batch_norm {
984            if self.rng.gen_bool(0.1) {
985                *bn = !*bn;
986            }
987        }
988
989        // Mutate skip connections
990        if self.rng.gen_bool(0.1) {
991            if self.rng.gen_bool(0.5) && !mutated.skip_connections.is_empty() {
992                // Remove a connection
993                let idx = self.rng.gen_range(0..mutated.skip_connections.len());
994                mutated.skip_connections.remove(idx);
995            } else if mutated.num_layers > 2 {
996                // Add a connection
997                let from = self.rng.gen_range(0..mutated.num_layers - 1);
998                let max_to =
999                    (from + self.config.search_space.max_skip_distance).min(mutated.num_layers - 1);
1000                if max_to > from {
1001                    let to = self.rng.gen_range(from + 1..max_to + 1);
1002                    let connection = (from, to);
1003                    if !mutated.skip_connections.contains(&connection) {
1004                        mutated.skip_connections.push(connection);
1005                    }
1006                }
1007            }
1008        }
1009
1010        mutated.complexity_score =
1011            self.calculate_complexity_score(&mutated.layer_sizes, &mutated.skip_connections);
1012
1013        Ok(mutated)
1014    }
1015
1016    /// Calculate architecture complexity score
1017    fn calculate_complexity_score(
1018        &self,
1019        layer_sizes: &[usize],
1020        skip_connections: &[(usize, usize)],
1021    ) -> Float {
1022        let total_params = layer_sizes.iter().sum::<usize>() as Float;
1023        let skip_penalty = skip_connections.len() as Float * 0.1;
1024        total_params + skip_penalty
1025    }
1026
1027    /// Check early stopping condition
1028    fn check_early_stopping(&self, convergence_curve: &[Float]) -> bool {
1029        if convergence_curve.len() < self.config.early_stopping_patience {
1030            return false;
1031        }
1032
1033        let recent_scores =
1034            &convergence_curve[convergence_curve.len() - self.config.early_stopping_patience..];
1035        let improvement = recent_scores.last().unwrap() - recent_scores.first().unwrap();
1036        improvement < 1e-6
1037    }
1038
1039    /// Generate architecture from policy weights (for RL)
1040    fn generate_architecture_from_policy(
1041        &mut self,
1042        _policy_weights: &HashMap<String, Float>,
1043    ) -> Result<NeuralArchitecture, Box<dyn std::error::Error>> {
1044        // Simplified policy-based generation
1045        // In practice, this would use learned policy weights
1046        self.generate_random_architecture()
1047    }
1048
1049    /// Update policy weights based on reward (for RL)
1050    fn update_policy_weights(
1051        &mut self,
1052        policy_weights: &mut HashMap<String, Float>,
1053        architecture: &NeuralArchitecture,
1054        reward: Float,
1055        learning_rate: Float,
1056    ) {
1057        // Simplified policy update
1058        // In practice, this would update weights based on architecture features
1059        let key = format!("layers_{}", architecture.num_layers);
1060        let current_weight = policy_weights.get(&key).unwrap_or(&0.0);
1061        policy_weights.insert(key, current_weight + learning_rate * reward);
1062    }
1063
1064    /// Initialize architecture parameters (for GDAS)
1065    fn initialize_architecture_parameters(
1066        &mut self,
1067    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
1068        let mut params = HashMap::new();
1069        params.insert("layer_weight".to_string(), 0.5);
1070        params.insert("activation_weight".to_string(), 0.5);
1071        params.insert("dropout_weight".to_string(), 0.5);
1072        Ok(params)
1073    }
1074
1075    /// Sample architecture from parameters (for GDAS)
1076    fn sample_architecture_from_params(
1077        &mut self,
1078        _params: &HashMap<String, Float>,
1079    ) -> Result<NeuralArchitecture, Box<dyn std::error::Error>> {
1080        // Simplified sampling based on parameters
1081        self.generate_random_architecture()
1082    }
1083
1084    /// Update architecture parameters (for GDAS)
1085    fn update_architecture_parameters(
1086        &mut self,
1087        params: &mut HashMap<String, Float>,
1088        evaluation: &ArchitectureEvaluation,
1089        learning_rate: Float,
1090    ) {
1091        // Simplified parameter update
1092        for (_key, value) in params.iter_mut() {
1093            *value += learning_rate * evaluation.validation_score * 0.01;
1094        }
1095    }
1096
1097    /// Generate architecture with complexity constraint
1098    fn generate_architecture_with_complexity(
1099        &mut self,
1100        max_complexity: Float,
1101    ) -> Result<NeuralArchitecture, Box<dyn std::error::Error>> {
1102        let mut architecture = self.generate_random_architecture()?;
1103
1104        // Adjust architecture to meet complexity constraint
1105        while architecture.complexity_score > max_complexity {
1106            if architecture.num_layers > 1 {
1107                architecture.num_layers -= 1;
1108                architecture.layer_sizes.pop();
1109                architecture.activations.pop();
1110                architecture.dropout_rates.pop();
1111                architecture.batch_norm.pop();
1112            }
1113
1114            // Reduce layer sizes
1115            for size in &mut architecture.layer_sizes {
1116                *size = ((*size as Float) * 0.8) as usize;
1117            }
1118
1119            architecture.complexity_score = self.calculate_complexity_score(
1120                &architecture.layer_sizes,
1121                &architecture.skip_connections,
1122            );
1123        }
1124
1125        Ok(architecture)
1126    }
1127
1128    /// Select next architecture using Bayesian optimization
1129    fn select_next_architecture_bayesian(
1130        &mut self,
1131        evaluated_architectures: &[(NeuralArchitecture, Float)],
1132    ) -> Result<NeuralArchitecture, Box<dyn std::error::Error>> {
1133        // Simplified Bayesian optimization - in practice would use GP surrogate model
1134        // For now, generate random architecture with slight bias towards better performing ones
1135        let mut architecture = self.generate_random_architecture()?;
1136
1137        // Bias towards architectures similar to best performing ones
1138        if !evaluated_architectures.is_empty() {
1139            let best_score = evaluated_architectures
1140                .iter()
1141                .map(|(_, score)| *score)
1142                .fold(Float::NEG_INFINITY, |a, b| a.max(b));
1143
1144            for (arch, score) in evaluated_architectures {
1145                if *score > best_score * 0.9 {
1146                    // Slightly bias towards similar architectures
1147                    if self.rng.gen_bool(0.3) {
1148                        architecture.num_layers = arch.num_layers;
1149                    }
1150                    break;
1151                }
1152            }
1153        }
1154
1155        Ok(architecture)
1156    }
1157}
1158
1159impl NeuralArchitecture {
1160    /// Calculate the number of parameters in the architecture
1161    pub fn parameter_count(&self) -> usize {
1162        if self.layer_sizes.is_empty() {
1163            return 0;
1164        }
1165
1166        let mut total = 0;
1167        for i in 0..self.layer_sizes.len() - 1 {
1168            total += self.layer_sizes[i] * self.layer_sizes[i + 1];
1169        }
1170        total
1171    }
1172
1173    /// Calculate estimated FLOPs for the architecture
1174    pub fn estimated_flops(&self) -> usize {
1175        self.parameter_count() * 2 // Rough estimate
1176    }
1177
1178    /// Get architecture summary
1179    pub fn summary(&self) -> String {
1180        format!(
1181            "Architecture: {} layers, {} parameters, complexity: {:.2}",
1182            self.num_layers,
1183            self.parameter_count(),
1184            self.complexity_score
1185        )
1186    }
1187}
1188
#[cfg(test)]
mod tests {
    use super::*;

    /// A small, internally consistent three-layer architecture shared by several tests.
    fn sample_architecture() -> NeuralArchitecture {
        NeuralArchitecture {
            num_layers: 3,
            layer_sizes: vec![64, 128, 32],
            activations: vec![
                "relu".to_string(),
                "relu".to_string(),
                "sigmoid".to_string(),
            ],
            dropout_rates: vec![0.1, 0.2, 0.0],
            batch_norm: vec![true, true, false],
            skip_connections: vec![],
            complexity_score: 224.0,
        }
    }

    #[test]
    fn test_nas_optimizer_creation() {
        let config = NASConfig::default();
        let optimizer = NASOptimizer::new(config);
        assert!(optimizer.config.search_space.layer_count_range.0 > 0);
    }

    #[test]
    fn test_random_architecture_generation() {
        let config = NASConfig::default();
        let mut optimizer = NASOptimizer::new(config);

        let architecture = optimizer.generate_random_architecture().unwrap();
        assert!(architecture.num_layers > 0);
        assert_eq!(architecture.layer_sizes.len(), architecture.num_layers);
        assert_eq!(architecture.activations.len(), architecture.num_layers);
    }

    #[test]
    fn test_architecture_complexity_calculation() {
        let config = NASConfig::default();
        let optimizer = NASOptimizer::new(config);

        let layer_sizes = vec![64, 128, 32];
        let skip_connections = vec![(0, 2)];

        // 64 + 128 + 32 neurons, plus 0.1 for the single skip connection.
        let complexity = optimizer.calculate_complexity_score(&layer_sizes, &skip_connections);
        assert!((complexity - 224.1).abs() < 1e-3, "complexity = {complexity}");
    }

    #[test]
    fn test_architecture_parameter_count() {
        let architecture = sample_architecture();

        let param_count = architecture.parameter_count();
        assert_eq!(param_count, 64 * 128 + 128 * 32); // 8192 + 4096 = 12288
    }

    #[test]
    fn test_parameter_count_degenerate_architectures() {
        let mut architecture = sample_architecture();

        architecture.layer_sizes = vec![];
        assert_eq!(architecture.parameter_count(), 0);

        architecture.layer_sizes = vec![64];
        assert_eq!(architecture.parameter_count(), 0);
    }

    #[test]
    fn test_architecture_flops_and_summary() {
        let architecture = sample_architecture();
        assert_eq!(architecture.estimated_flops(), 2 * 12288);
        assert_eq!(
            architecture.summary(),
            "Architecture: 3 layers, 12288 parameters, complexity: 224.00"
        );
    }

    #[test]
    fn test_nas_random_search() {
        let config = NASConfig {
            strategy: NASStrategy::RandomSearch {
                n_trials: 5,
                max_depth: 5,
            },
            ..Default::default()
        };
        let mut optimizer = NASOptimizer::new(config);

        let evaluation_fn = |arch: &NeuralArchitecture| -> Result<ArchitectureEvaluation, Box<dyn std::error::Error>> {
            Ok(ArchitectureEvaluation {
                architecture: arch.clone(),
                validation_score: 0.8,
                training_time: 10.0,
                parameters_count: arch.parameter_count(),
                flops: arch.estimated_flops(),
                memory_usage: 100.0,
            })
        };

        let result = optimizer.search(evaluation_fn).unwrap();
        assert!(result.best_score > 0.0);
        assert_eq!(result.architectures_evaluated, 5);
    }
}