sklears_model_selection/
meta_learning.rs

1//! Meta-Learning for Hyperparameter Initialization
2//!
3//! This module provides meta-learning capabilities for intelligent hyperparameter initialization
4//! based on historical optimization data, dataset characteristics, and algorithm performance patterns.
5//! It learns from past optimization experiences to provide better starting points for new optimization tasks.
6
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::SeedableRng;
10use serde::{Deserialize, Serialize};
11use sklears_core::types::Float;
12use std::collections::HashMap;
13
/// Meta-learning strategies for hyperparameter initialization
///
/// The engine matches on this value to decide which recommendation routine
/// runs; each variant carries the settings for one strategy.
#[derive(Debug, Clone)]
pub enum MetaLearningStrategy {
    /// Similarity-based recommendation using dataset characteristics
    SimilarityBased {
        /// Metric used to compare the query dataset with each historical record.
        similarity_metric: SimilarityMetric,

        /// Maximum number of most-similar records kept for aggregation.
        k_neighbors: usize,

        /// When `true`, each kept record is weighted by its similarity score;
        /// when `false`, every kept record has weight 1.0.
        weight_by_distance: bool,
    },
    /// Model-based meta-learning using surrogate models
    ModelBased {
        /// Requested surrogate model family.
        surrogate_model: SurrogateModel,

        /// NOTE(review): not read by the code visible in this section;
        /// presumably how often the surrogate is refit — confirm.
        update_frequency: usize,
    },
    /// Gradient-based meta-learning (MAML-style)
    ///
    /// NOTE(review): the gradient-based routine visible in this section does
    /// not read any of these fields.
    GradientBased {
        meta_learning_rate: Float,
        adaptation_steps: usize,
        inner_learning_rate: Float,
    },
    /// Bayesian meta-learning using hierarchical models
    ///
    /// NOTE(review): the Bayesian routine visible in this section does not
    /// read either field.
    BayesianMeta {
        prior_strength: Float,
        hierarchical_levels: usize,
    },
    /// Transfer learning from similar tasks
    ///
    /// NOTE(review): the transfer-learning routine visible in this section
    /// does not read either field.
    TransferLearning {
        transfer_method: TransferMethod,
        domain_adaptation: bool,
    },
    /// Ensemble of meta-learners
    EnsembleMeta {
        /// Member strategies; each is evaluated on its own temporary engine.
        strategies: Vec<MetaLearningStrategy>,
        /// How the members' recommendations are merged into one.
        combination_method: CombinationMethod,
    },
}
53
/// Similarity metrics for dataset comparison
///
/// Only `Cosine`, `Euclidean` and `Manhattan` have dedicated handling in the
/// similarity computation in this file; the remaining variants currently fall
/// back to cosine similarity.
#[derive(Debug, Clone)]
pub enum SimilarityMetric {
    /// Cosine similarity between dataset statistics
    Cosine,
    /// Euclidean distance between features, mapped to `1 / (1 + distance)`
    Euclidean,
    /// Manhattan distance, mapped to `1 / (1 + distance)`
    Manhattan,
    /// Pearson correlation (currently treated as cosine similarity)
    Correlation,
    /// Jensen-Shannon divergence (currently treated as cosine similarity)
    JensenShannon,
    /// Learned similarity using neural networks (currently treated as cosine similarity)
    Learned,
}
70
/// Surrogate models for meta-learning
///
/// NOTE(review): the surrogate training code visible in this section always
/// builds a random-forest surrogate with fixed settings and does not consult
/// this enum — confirm whether the selection is honoured elsewhere.
#[derive(Debug, Clone)]
pub enum SurrogateModel {
    /// Random Forest for hyperparameter prediction
    RandomForest {
        /// Number of trees in the forest.
        n_trees: usize,

        /// Optional depth limit per tree.
        max_depth: Option<usize>,
    },
    /// Gaussian Process for uncertainty quantification
    GaussianProcess { kernel_type: String },
    /// Neural Network for complex patterns
    NeuralNetwork { hidden_layers: Vec<usize> },
    /// Linear model for simple relationships
    LinearRegression { regularization: Float },
}
87
/// Transfer learning methods
///
/// Carried by `MetaLearningStrategy::TransferLearning`.
/// NOTE(review): the transfer-learning routine visible in this section does
/// not branch on this value.
#[derive(Debug, Clone)]
pub enum TransferMethod {
    /// Direct parameter transfer
    DirectTransfer,
    /// Feature-based transfer
    FeatureTransfer,
    /// Model-based transfer
    ModelTransfer,
    /// Instance-based transfer
    InstanceTransfer,
}
100
/// Combination methods for ensemble meta-learning
///
/// The ensemble routine has dedicated handling for `Average` and
/// `WeightedAverage`; `Stacking` and `BayesianAveraging` currently fall back
/// to plain averaging.
#[derive(Debug, Clone)]
pub enum CombinationMethod {
    /// Average predictions
    Average,
    /// Weighted average by performance
    WeightedAverage,
    /// Stacking with meta-model (currently handled as `Average`)
    Stacking,
    /// Bayesian model averaging (currently handled as `Average`)
    BayesianAveraging,
}
113
/// Dataset characteristics for meta-learning
///
/// A subset of these fields is flattened into the numeric feature vector used
/// for similarity computation and surrogate training: the three size fields,
/// `statistical_measures.mean_values`, and four of the complexity measures.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetCharacteristics {
    /// Number of samples; part of the feature vector.
    pub n_samples: usize,
    /// Number of features; part of the feature vector.
    pub n_features: usize,
    /// Number of classes; `None` contributes 0 to the feature vector.
    pub n_classes: Option<usize>,
    /// Per-class balance values (not part of the feature vector).
    pub class_balance: Vec<Float>,
    /// Declared type of each feature (not part of the feature vector).
    pub feature_types: Vec<FeatureType>,
    /// Per-feature statistics; only `mean_values` enters the feature vector.
    pub statistical_measures: StatisticalMeasures,
    /// Dataset complexity measures; four of them enter the feature vector.
    pub complexity_measures: ComplexityMeasures,
    /// Free-form named measures (not part of the feature vector).
    pub domain_specific: HashMap<String, Float>,
}
126
/// Feature types
///
/// Tags the kind of data held by each feature of a dataset
/// (see `DatasetCharacteristics::feature_types`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FeatureType {
    /// Numerical feature
    Numerical,
    /// Categorical feature
    Categorical,
    /// Ordinal feature
    Ordinal,
    /// Text feature
    Text,
    /// Image feature
    Image,
    /// Time-series feature
    TimeSeries,
}
143
/// Statistical measures of the dataset
///
/// Only `mean_values` is consumed by the feature extraction visible in this
/// section; the other fields are carried as data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatisticalMeasures {
    /// Mean values; appended verbatim to the dataset feature vector, so their
    /// count determines that vector's length.
    pub mean_values: Vec<Float>,
    /// Standard deviation values.
    pub std_values: Vec<Float>,
    /// Skewness values.
    pub skewness: Vec<Float>,
    /// Kurtosis values.
    pub kurtosis: Vec<Float>,
    /// Optional correlation matrix.
    pub correlation_matrix: Option<Array2<Float>>,
    /// Optional mutual-information values.
    pub mutual_information: Option<Vec<Float>>,
}
154
/// Complexity measures of the dataset
///
/// `fisher_discriminant_ratio`, `volume_of_overlap`, `feature_efficiency` and
/// `entropy` are appended to the dataset feature vector; the remaining two
/// fields are not used by the feature extraction visible in this section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplexityMeasures {
    /// Part of the feature vector.
    pub fisher_discriminant_ratio: Float,
    /// Part of the feature vector.
    pub volume_of_overlap: Float,
    /// Part of the feature vector.
    pub feature_efficiency: Float,
    /// Not part of the feature vector.
    pub collective_feature_efficiency: Float,
    /// Part of the feature vector.
    pub entropy: Float,
    /// Not part of the feature vector.
    pub class_probability_max: Float,
}
165
/// Historical optimization record
///
/// One past optimization run: which algorithm ran on which dataset, the
/// hyperparameters it ended with and how well it scored. Records are the raw
/// material for every recommendation strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationRecord {
    /// Identifier reported back in `MetaLearningRecommendation::similar_datasets`.
    pub dataset_id: String,
    /// Algorithm name; recommendations only consider records whose name
    /// equals the requested algorithm exactly.
    pub algorithm_name: String,
    /// Characteristics of the dataset the run was performed on.
    pub dataset_characteristics: DatasetCharacteristics,
    /// Hyperparameter values of the run, keyed by parameter name.
    pub hyperparameters: HashMap<String, ParameterValue>,
    /// Score of the run; used as the surrogate training target and as the
    /// basis for expected-performance estimates.
    pub performance_score: Float,
    /// Time of the run; used as the basis for expected-runtime estimates
    /// (units are not fixed by this file).
    pub optimization_time: Float,
    /// Iteration count of the run (not read in the code visible here).
    pub convergence_iterations: usize,
    /// Validation method label (not read in the code visible here).
    pub validation_method: String,
    /// Timestamp of the run (not read in the code visible here).
    pub timestamp: u64,
}
179
/// Parameter value types
///
/// A dynamically typed hyperparameter value. The similarity-based strategy
/// averages `Float` and `Integer` values; the Bayesian strategy and the
/// recommendation averaging only look at `Float` values.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ParameterValue {
    /// Floating-point value
    Float(Float),
    /// Signed integer value
    Integer(i64),
    /// Boolean flag
    Boolean(bool),
    /// String value
    String(String),
    /// Array of floating-point values
    Array(Vec<Float>),
}
194
/// Meta-learning recommendations
///
/// Result of a recommendation request: suggested starting hyperparameters
/// plus estimates describing how much to trust them.
#[derive(Debug, Clone)]
pub struct MetaLearningRecommendation {
    /// Suggested value per hyperparameter name.
    pub recommended_hyperparameters: HashMap<String, ParameterValue>,
    /// Confidence per hyperparameter name; how it is derived depends on the strategy.
    pub confidence_scores: HashMap<String, Float>,
    /// Estimated performance score for the recommendation.
    pub expected_performance: Float,
    /// Estimated runtime; some strategies fill in a fixed placeholder of 100.0.
    pub expected_runtime: Float,
    /// `dataset_id`s of the historical records the recommendation drew on.
    pub similar_datasets: Vec<String>,
    /// Name of the strategy that produced the recommendation (e.g. "SimilarityBased").
    pub recommendation_source: String,
    /// Uncertainty of the recommendation; derivation depends on the strategy.
    pub uncertainty_estimate: Float,
}
206
/// Meta-learning configuration
#[derive(Debug, Clone)]
pub struct MetaLearningConfig {
    /// Strategy used to produce recommendations.
    pub strategy: MetaLearningStrategy,
    /// Minimum number of stored records required before any recommendation is
    /// attempted; below this the engine returns an error.
    pub min_historical_records: usize,
    /// NOTE(review): not read in the code visible here.
    pub max_similarity_distance: Float,
    /// NOTE(review): not read in the code visible here.
    pub confidence_threshold: Float,
    /// Surrogate models are refit whenever the record count reaches a multiple
    /// of this value while records are added one at a time.
    pub update_interval: usize,
    /// NOTE(review): not read in the code visible here.
    pub cache_size: usize,
    /// Seed for the engine's random number generator; `None` seeds it from
    /// the thread-local generator.
    pub random_state: Option<u64>,
}
218
/// Meta-learning engine
///
/// Stores historical optimization records and the surrogate models trained
/// from them, and produces hyperparameter recommendations for new datasets.
#[derive(Debug)]
pub struct MetaLearningEngine {
    /// Engine configuration, including the active strategy.
    config: MetaLearningConfig,
    /// All optimization records loaded or added so far.
    historical_records: Vec<OptimizationRecord>,
    /// Starts empty. NOTE(review): not otherwise touched in the code visible here.
    dataset_similarity_cache: HashMap<String, Vec<(String, Float)>>,
    /// Trained surrogate models keyed by `"<algorithm_name>_hyperparams"`.
    surrogate_models: HashMap<String, Box<dyn SurrogateModelTrait>>,
    /// Random number generator, seeded from `config.random_state` when set.
    rng: StdRng,
}
228
/// Trait for surrogate models
///
/// Object-safe interface that lets the engine store different surrogate
/// implementations behind `Box<dyn SurrogateModelTrait>`.
trait SurrogateModelTrait: std::fmt::Debug {
    /// Trains the model on a feature matrix and a target vector.
    ///
    /// The engine passes one row per historical record, with the record's
    /// performance score as the target.
    fn fit(
        &mut self,
        features: &Array2<Float>,
        targets: &Array1<Float>,
    ) -> Result<(), Box<dyn std::error::Error>>;
    /// Predicts target values for the given feature matrix.
    fn predict(
        &self,
        features: &Array2<Float>,
    ) -> Result<Array1<Float>, Box<dyn std::error::Error>>;
    /// Predicts target values together with an uncertainty array.
    ///
    /// Callers in this file read the first tuple element as predictions and
    /// the second as the matching uncertainties.
    fn predict_with_uncertainty(
        &self,
        features: &Array2<Float>,
    ) -> Result<(Array1<Float>, Array1<Float>), Box<dyn std::error::Error>>;
}
245
/// Simple Random Forest surrogate model implementation
///
/// The engine trains one instance per algorithm and stores it behind
/// `SurrogateModelTrait`. The method implementations are outside this section.
#[derive(Debug)]
struct RandomForestSurrogate {
    /// Number of trees in the forest.
    n_trees: usize,
    /// Optional depth limit per tree.
    max_depth: Option<usize>,
    /// The trees themselves. NOTE(review): presumably filled in by `fit` — confirm.
    models: Vec<SimpleTree>,
}
253
/// Simple decision tree for surrogate model
///
/// Recursive node type: every field is optional and children are boxed
/// subtrees.
/// NOTE(review): the code that builds and walks the tree is outside this
/// section; presumably split nodes set `feature_idx`/`threshold`/children and
/// leaves set `prediction` — confirm.
#[derive(Debug, Clone)]
struct SimpleTree {
    /// Index of the feature tested at this node, if any.
    feature_idx: Option<usize>,
    /// Threshold the feature is compared against, if any.
    threshold: Option<Float>,
    /// Left subtree, if any.
    left: Option<Box<SimpleTree>>,
    /// Right subtree, if any.
    right: Option<Box<SimpleTree>>,
    /// Value predicted at this node, if any.
    prediction: Option<Float>,
}
263
264impl Default for MetaLearningConfig {
265    fn default() -> Self {
266        Self {
267            strategy: MetaLearningStrategy::SimilarityBased {
268                similarity_metric: SimilarityMetric::Cosine,
269                k_neighbors: 5,
270                weight_by_distance: true,
271            },
272            min_historical_records: 10,
273            max_similarity_distance: 0.8,
274            confidence_threshold: 0.6,
275            update_interval: 100,
276            cache_size: 1000,
277            random_state: None,
278        }
279    }
280}
281
282impl MetaLearningEngine {
283    /// Create a new meta-learning engine
284    pub fn new(config: MetaLearningConfig) -> Self {
285        let rng = match config.random_state {
286            Some(seed) => StdRng::seed_from_u64(seed),
287            None => {
288                use scirs2_core::random::thread_rng;
289                StdRng::from_rng(&mut thread_rng())
290            }
291        };
292
293        Self {
294            config,
295            historical_records: Vec::new(),
296            dataset_similarity_cache: HashMap::new(),
297            surrogate_models: HashMap::new(),
298            rng,
299        }
300    }
301
302    /// Load historical optimization records
303    pub fn load_historical_records(&mut self, records: Vec<OptimizationRecord>) {
304        self.historical_records.extend(records);
305        self.update_surrogate_models().unwrap_or_else(|e| {
306            eprintln!("Warning: Failed to update surrogate models: {}", e);
307        });
308    }
309
310    /// Add a new optimization record
311    pub fn add_record(&mut self, record: OptimizationRecord) {
312        self.historical_records.push(record);
313
314        // Update models periodically
315        if self.historical_records.len() % self.config.update_interval == 0 {
316            self.update_surrogate_models().unwrap_or_else(|e| {
317                eprintln!("Warning: Failed to update surrogate models: {}", e);
318            });
319        }
320    }
321
322    /// Get hyperparameter recommendations for a new dataset
323    pub fn recommend_hyperparameters(
324        &mut self,
325        dataset_characteristics: &DatasetCharacteristics,
326        algorithm_name: &str,
327    ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
328        if self.historical_records.len() < self.config.min_historical_records {
329            return Err("Insufficient historical data for meta-learning".into());
330        }
331
332        match &self.config.strategy {
333            MetaLearningStrategy::SimilarityBased { .. } => {
334                self.similarity_based_recommendation(dataset_characteristics, algorithm_name)
335            }
336            MetaLearningStrategy::ModelBased { .. } => {
337                self.model_based_recommendation(dataset_characteristics, algorithm_name)
338            }
339            MetaLearningStrategy::GradientBased { .. } => {
340                self.gradient_based_recommendation(dataset_characteristics, algorithm_name)
341            }
342            MetaLearningStrategy::BayesianMeta { .. } => {
343                self.bayesian_meta_recommendation(dataset_characteristics, algorithm_name)
344            }
345            MetaLearningStrategy::TransferLearning { .. } => {
346                self.transfer_learning_recommendation(dataset_characteristics, algorithm_name)
347            }
348            MetaLearningStrategy::EnsembleMeta { .. } => {
349                self.ensemble_meta_recommendation(dataset_characteristics, algorithm_name)
350            }
351        }
352    }
353
354    /// Similarity-based recommendation
355    fn similarity_based_recommendation(
356        &mut self,
357        dataset_characteristics: &DatasetCharacteristics,
358        algorithm_name: &str,
359    ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
360        let (similarity_metric, k_neighbors, weight_by_distance) = match &self.config.strategy {
361            MetaLearningStrategy::SimilarityBased {
362                similarity_metric,
363                k_neighbors,
364                weight_by_distance,
365            } => (similarity_metric, *k_neighbors, *weight_by_distance),
366            _ => unreachable!(),
367        };
368
369        // Find similar datasets
370        let mut similarities = Vec::new();
371        for record in &self.historical_records {
372            if record.algorithm_name == algorithm_name {
373                let similarity = self.calculate_similarity(
374                    dataset_characteristics,
375                    &record.dataset_characteristics,
376                    similarity_metric,
377                )?;
378                similarities.push((record, similarity));
379            }
380        }
381
382        // Sort by similarity and take top k
383        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
384        similarities.truncate(k_neighbors);
385
386        if similarities.is_empty() {
387            return Err("No similar datasets found".into());
388        }
389
390        // Aggregate hyperparameters from similar datasets
391        let mut aggregated_hyperparameters = HashMap::new();
392        let mut confidence_scores = HashMap::new();
393        let mut expected_performance = 0.0;
394        let mut expected_runtime = 0.0;
395        let mut total_weight = 0.0;
396
397        for (record, similarity) in &similarities {
398            let weight = if weight_by_distance { *similarity } else { 1.0 };
399            total_weight += weight;
400
401            expected_performance += record.performance_score * weight;
402            expected_runtime += record.optimization_time * weight;
403
404            for (param_name, param_value) in &record.hyperparameters {
405                match param_value {
406                    ParameterValue::Float(val) => {
407                        let entry = aggregated_hyperparameters
408                            .entry(param_name.clone())
409                            .or_insert_with(|| (0.0, 0.0)); // (sum, weight_sum)
410                        entry.0 += val * weight;
411                        entry.1 += weight;
412                    }
413                    ParameterValue::Integer(val) => {
414                        let entry = aggregated_hyperparameters
415                            .entry(param_name.clone())
416                            .or_insert_with(|| (0.0, 0.0));
417                        entry.0 += *val as Float * weight;
418                        entry.1 += weight;
419                    }
420                    _ => {
421                        // For non-numeric parameters, use the most common value
422                        // Simplified implementation
423                    }
424                }
425
426                confidence_scores.insert(param_name.clone(), *similarity);
427            }
428        }
429
430        // Convert aggregated values to recommendations
431        let mut recommended_hyperparameters = HashMap::new();
432        for (param_name, (sum, weight_sum)) in aggregated_hyperparameters {
433            let avg_value = sum / weight_sum;
434            recommended_hyperparameters.insert(param_name, ParameterValue::Float(avg_value));
435        }
436
437        expected_performance /= total_weight;
438        expected_runtime /= total_weight;
439
440        let similar_datasets = similarities
441            .iter()
442            .map(|(record, _)| record.dataset_id.clone())
443            .collect();
444
445        let uncertainty_estimate = 1.0
446            - similarities.iter().map(|(_, sim)| sim).sum::<Float>() / similarities.len() as Float;
447
448        Ok(MetaLearningRecommendation {
449            recommended_hyperparameters,
450            confidence_scores,
451            expected_performance,
452            expected_runtime,
453            similar_datasets,
454            recommendation_source: "SimilarityBased".to_string(),
455            uncertainty_estimate,
456        })
457    }
458
459    /// Model-based recommendation using surrogate models
460    fn model_based_recommendation(
461        &mut self,
462        dataset_characteristics: &DatasetCharacteristics,
463        algorithm_name: &str,
464    ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
465        let model_key = format!("{}_{}", algorithm_name, "hyperparams");
466
467        if let Some(model) = self.surrogate_models.get(&model_key) {
468            let features = self.extract_features(dataset_characteristics)?;
469            let features_2d = Array2::from_shape_vec((1, features.len()), features.to_vec())?;
470
471            let (predictions, uncertainties) = model.predict_with_uncertainty(&features_2d)?;
472
473            // Convert predictions to hyperparameter recommendations
474            let mut recommended_hyperparameters = HashMap::new();
475            let mut confidence_scores = HashMap::new();
476
477            // Simplified: assume first prediction is for a specific hyperparameter
478            recommended_hyperparameters.insert(
479                "learning_rate".to_string(),
480                ParameterValue::Float(predictions[0]),
481            );
482            confidence_scores.insert("learning_rate".to_string(), 1.0 - uncertainties[0]);
483
484            Ok(MetaLearningRecommendation {
485                recommended_hyperparameters,
486                confidence_scores,
487                expected_performance: predictions[0],
488                expected_runtime: 100.0, // Placeholder
489                similar_datasets: vec![],
490                recommendation_source: "ModelBased".to_string(),
491                uncertainty_estimate: uncertainties[0],
492            })
493        } else {
494            Err("Surrogate model not available".into())
495        }
496    }
497
498    /// Gradient-based meta-learning recommendation
499    fn gradient_based_recommendation(
500        &mut self,
501        _dataset_characteristics: &DatasetCharacteristics,
502        algorithm_name: &str,
503    ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
504        // Simplified gradient-based meta-learning
505        // In practice, this would implement MAML or similar algorithms
506
507        let similar_records: Vec<&OptimizationRecord> = self
508            .historical_records
509            .iter()
510            .filter(|r| r.algorithm_name == algorithm_name)
511            .collect();
512
513        if similar_records.is_empty() {
514            return Err("No historical records for algorithm".into());
515        }
516
517        // Simulate gradient-based adaptation
518        let mut adapted_hyperparameters = HashMap::new();
519        let mut confidence_scores = HashMap::new();
520
521        // Use the best performing record as starting point
522        let best_record = similar_records
523            .iter()
524            .max_by(|a, b| {
525                a.performance_score
526                    .partial_cmp(&b.performance_score)
527                    .unwrap()
528            })
529            .unwrap();
530
531        for (param_name, param_value) in &best_record.hyperparameters {
532            adapted_hyperparameters.insert(param_name.clone(), param_value.clone());
533            confidence_scores.insert(param_name.clone(), 0.8); // High confidence from gradient adaptation
534        }
535
536        Ok(MetaLearningRecommendation {
537            recommended_hyperparameters: adapted_hyperparameters,
538            confidence_scores,
539            expected_performance: best_record.performance_score,
540            expected_runtime: best_record.optimization_time,
541            similar_datasets: vec![best_record.dataset_id.clone()],
542            recommendation_source: "GradientBased".to_string(),
543            uncertainty_estimate: 0.2,
544        })
545    }
546
547    /// Bayesian meta-learning recommendation
548    fn bayesian_meta_recommendation(
549        &mut self,
550        _dataset_characteristics: &DatasetCharacteristics,
551        algorithm_name: &str,
552    ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
553        // Simplified Bayesian meta-learning
554        let relevant_records: Vec<&OptimizationRecord> = self
555            .historical_records
556            .iter()
557            .filter(|r| r.algorithm_name == algorithm_name)
558            .collect();
559
560        if relevant_records.is_empty() {
561            return Err("No historical records for algorithm".into());
562        }
563
564        // Bayesian inference for hyperparameter distributions
565        let mut hyperparameter_distributions = HashMap::new();
566        let mut confidence_scores = HashMap::new();
567
568        for record in &relevant_records {
569            for (param_name, param_value) in &record.hyperparameters {
570                if let ParameterValue::Float(val) = param_value {
571                    let entry = hyperparameter_distributions
572                        .entry(param_name.clone())
573                        .or_insert_with(Vec::new);
574                    entry.push(*val);
575                }
576            }
577        }
578
579        let mut recommended_hyperparameters = HashMap::new();
580        for (param_name, values) in hyperparameter_distributions {
581            let mean = values.iter().sum::<Float>() / values.len() as Float;
582            let variance =
583                values.iter().map(|v| (v - mean).powi(2)).sum::<Float>() / values.len() as Float;
584
585            recommended_hyperparameters.insert(param_name.clone(), ParameterValue::Float(mean));
586            confidence_scores.insert(param_name, 1.0 / (1.0 + variance)); // Higher confidence for lower variance
587        }
588
589        let avg_performance = relevant_records
590            .iter()
591            .map(|r| r.performance_score)
592            .sum::<Float>()
593            / relevant_records.len() as Float;
594
595        Ok(MetaLearningRecommendation {
596            recommended_hyperparameters,
597            confidence_scores,
598            expected_performance: avg_performance,
599            expected_runtime: 100.0, // Placeholder
600            similar_datasets: relevant_records
601                .iter()
602                .map(|r| r.dataset_id.clone())
603                .collect(),
604            recommendation_source: "BayesianMeta".to_string(),
605            uncertainty_estimate: 0.3,
606        })
607    }
608
609    /// Transfer learning recommendation
610    fn transfer_learning_recommendation(
611        &mut self,
612        dataset_characteristics: &DatasetCharacteristics,
613        algorithm_name: &str,
614    ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
615        // Find the most similar domain
616        let mut best_similarity = 0.0;
617        let mut best_record = None;
618
619        for record in &self.historical_records {
620            if record.algorithm_name == algorithm_name {
621                let similarity = self.calculate_similarity(
622                    dataset_characteristics,
623                    &record.dataset_characteristics,
624                    &SimilarityMetric::Cosine,
625                )?;
626
627                if similarity > best_similarity {
628                    best_similarity = similarity;
629                    best_record = Some(record);
630                }
631            }
632        }
633
634        if let Some(record) = best_record {
635            let mut confidence_scores = HashMap::new();
636            for param_name in record.hyperparameters.keys() {
637                confidence_scores.insert(param_name.clone(), best_similarity);
638            }
639
640            Ok(MetaLearningRecommendation {
641                recommended_hyperparameters: record.hyperparameters.clone(),
642                confidence_scores,
643                expected_performance: record.performance_score * best_similarity,
644                expected_runtime: record.optimization_time,
645                similar_datasets: vec![record.dataset_id.clone()],
646                recommendation_source: "TransferLearning".to_string(),
647                uncertainty_estimate: 1.0 - best_similarity,
648            })
649        } else {
650            Err("No suitable source domain found for transfer learning".into())
651        }
652    }
653
654    /// Ensemble meta-learning recommendation
655    fn ensemble_meta_recommendation(
656        &mut self,
657        dataset_characteristics: &DatasetCharacteristics,
658        algorithm_name: &str,
659    ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
660        let (strategies, combination_method) = match &self.config.strategy {
661            MetaLearningStrategy::EnsembleMeta {
662                strategies,
663                combination_method,
664            } => (strategies, combination_method),
665            _ => unreachable!(),
666        };
667
668        let mut recommendations = Vec::new();
669
670        // Get recommendations from each strategy
671        for strategy in strategies {
672            let mut temp_config = self.config.clone();
673            temp_config.strategy = strategy.clone();
674            let mut temp_engine = MetaLearningEngine::new(temp_config);
675            temp_engine.historical_records = self.historical_records.clone();
676
677            if let Ok(rec) =
678                temp_engine.recommend_hyperparameters(dataset_characteristics, algorithm_name)
679            {
680                recommendations.push(rec);
681            }
682        }
683
684        if recommendations.is_empty() {
685            return Err("No recommendations from ensemble strategies".into());
686        }
687
688        // Combine recommendations
689        match combination_method {
690            CombinationMethod::Average => self.average_recommendations(recommendations),
691            CombinationMethod::WeightedAverage => {
692                self.weighted_average_recommendations(recommendations)
693            }
694            _ => {
695                // Default to average for other methods
696                self.average_recommendations(recommendations)
697            }
698        }
699    }
700
701    /// Calculate similarity between datasets
702    fn calculate_similarity(
703        &self,
704        dataset1: &DatasetCharacteristics,
705        dataset2: &DatasetCharacteristics,
706        metric: &SimilarityMetric,
707    ) -> Result<Float, Box<dyn std::error::Error>> {
708        let features1 = self.extract_features(dataset1)?;
709        let features2 = self.extract_features(dataset2)?;
710
711        match metric {
712            SimilarityMetric::Cosine => {
713                let dot_product = features1
714                    .iter()
715                    .zip(features2.iter())
716                    .map(|(a, b)| a * b)
717                    .sum::<Float>();
718                let norm1 = (features1.iter().map(|x| x * x).sum::<Float>()).sqrt();
719                let norm2 = (features2.iter().map(|x| x * x).sum::<Float>()).sqrt();
720                Ok(dot_product / (norm1 * norm2))
721            }
722            SimilarityMetric::Euclidean => {
723                let distance = features1
724                    .iter()
725                    .zip(features2.iter())
726                    .map(|(a, b)| (a - b).powi(2))
727                    .sum::<Float>()
728                    .sqrt();
729                Ok(1.0 / (1.0 + distance))
730            }
731            SimilarityMetric::Manhattan => {
732                let distance = features1
733                    .iter()
734                    .zip(features2.iter())
735                    .map(|(a, b)| (a - b).abs())
736                    .sum::<Float>();
737                Ok(1.0 / (1.0 + distance))
738            }
739            _ => {
740                // Default to cosine similarity
741                let dot_product = features1
742                    .iter()
743                    .zip(features2.iter())
744                    .map(|(a, b)| a * b)
745                    .sum::<Float>();
746                let norm1 = (features1.iter().map(|x| x * x).sum::<Float>()).sqrt();
747                let norm2 = (features2.iter().map(|x| x * x).sum::<Float>()).sqrt();
748                Ok(dot_product / (norm1 * norm2))
749            }
750        }
751    }
752
753    /// Extract feature vector from dataset characteristics
754    fn extract_features(
755        &self,
756        characteristics: &DatasetCharacteristics,
757    ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
758        let mut features = Vec::new();
759
760        // Basic dataset statistics
761        features.push(characteristics.n_samples as Float);
762        features.push(characteristics.n_features as Float);
763        features.push(characteristics.n_classes.unwrap_or(0) as Float);
764
765        // Statistical measures
766        if !characteristics.statistical_measures.mean_values.is_empty() {
767            features.extend(&characteristics.statistical_measures.mean_values);
768        }
769
770        // Complexity measures
771        features.push(
772            characteristics
773                .complexity_measures
774                .fisher_discriminant_ratio,
775        );
776        features.push(characteristics.complexity_measures.volume_of_overlap);
777        features.push(characteristics.complexity_measures.feature_efficiency);
778        features.push(characteristics.complexity_measures.entropy);
779
780        Ok(Array1::from_vec(features))
781    }
782
783    /// Update surrogate models with new data
784    fn update_surrogate_models(&mut self) -> Result<(), Box<dyn std::error::Error>> {
785        // Group records by algorithm
786        let mut algorithm_groups: HashMap<String, Vec<&OptimizationRecord>> = HashMap::new();
787
788        for record in &self.historical_records {
789            algorithm_groups
790                .entry(record.algorithm_name.clone())
791                .or_default()
792                .push(record);
793        }
794
795        // Train surrogate models for each algorithm
796        for (algorithm_name, records) in algorithm_groups {
797            if records.len() >= 5 {
798                // Minimum records for training
799                let model_key = format!("{}_{}", algorithm_name, "hyperparams");
800
801                // Extract features and targets
802                let mut features_vec = Vec::new();
803                let mut targets = Vec::new();
804
805                for record in &records {
806                    let features = self.extract_features(&record.dataset_characteristics)?;
807                    features_vec.extend(features.to_vec());
808                    targets.push(record.performance_score);
809                }
810
811                let n_features = self
812                    .extract_features(&records[0].dataset_characteristics)?
813                    .len();
814                let features_2d =
815                    Array2::from_shape_vec((records.len(), n_features), features_vec)?;
816                let targets_1d = Array1::from_vec(targets);
817
818                // Create and train surrogate model
819                let mut surrogate = Box::new(RandomForestSurrogate::new(10, Some(5)));
820                surrogate.fit(&features_2d, &targets_1d)?;
821
822                self.surrogate_models.insert(model_key, surrogate);
823            }
824        }
825
826        Ok(())
827    }
828
829    /// Average multiple recommendations
830    fn average_recommendations(
831        &self,
832        recommendations: Vec<MetaLearningRecommendation>,
833    ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
834        if recommendations.is_empty() {
835            return Err("No recommendations to average".into());
836        }
837
838        let mut aggregated_hyperparameters = HashMap::new();
839        let mut confidence_scores = HashMap::new();
840        let mut expected_performance = 0.0;
841        let mut expected_runtime = 0.0;
842        let mut uncertainty_estimate = 0.0;
843
844        let n_recommendations = recommendations.len() as Float;
845
846        for rec in &recommendations {
847            expected_performance += rec.expected_performance;
848            expected_runtime += rec.expected_runtime;
849            uncertainty_estimate += rec.uncertainty_estimate;
850
851            for (param_name, param_value) in &rec.recommended_hyperparameters {
852                if let ParameterValue::Float(val) = param_value {
853                    *aggregated_hyperparameters
854                        .entry(param_name.clone())
855                        .or_insert(0.0) += val;
856                }
857            }
858
859            for (param_name, confidence) in &rec.confidence_scores {
860                *confidence_scores.entry(param_name.clone()).or_insert(0.0) += confidence;
861            }
862        }
863
864        // Average the values
865        let mut recommended_hyperparameters = HashMap::new();
866        for (param_name, sum) in aggregated_hyperparameters {
867            recommended_hyperparameters.insert(
868                param_name.clone(),
869                ParameterValue::Float(sum / n_recommendations),
870            );
871            if let Some(conf_sum) = confidence_scores.get_mut(&param_name) {
872                *conf_sum /= n_recommendations;
873            }
874        }
875
876        Ok(MetaLearningRecommendation {
877            recommended_hyperparameters,
878            confidence_scores,
879            expected_performance: expected_performance / n_recommendations,
880            expected_runtime: expected_runtime / n_recommendations,
881            similar_datasets: vec![], // Combine all similar datasets if needed
882            recommendation_source: "EnsembleAverage".to_string(),
883            uncertainty_estimate: uncertainty_estimate / n_recommendations,
884        })
885    }
886
887    /// Weighted average of recommendations
888    fn weighted_average_recommendations(
889        &self,
890        recommendations: Vec<MetaLearningRecommendation>,
891    ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
892        if recommendations.is_empty() {
893            return Err("No recommendations to average".into());
894        }
895
896        // Use confidence as weights (inverse of uncertainty)
897        let weights: Vec<Float> = recommendations
898            .iter()
899            .map(|r| 1.0 - r.uncertainty_estimate)
900            .collect();
901
902        let total_weight: Float = weights.iter().sum();
903
904        let mut aggregated_hyperparameters = HashMap::new();
905        let mut confidence_scores = HashMap::new();
906        let mut expected_performance = 0.0;
907        let mut expected_runtime = 0.0;
908        let mut uncertainty_estimate = 0.0;
909
910        for (i, rec) in recommendations.iter().enumerate() {
911            let weight = weights[i] / total_weight;
912
913            expected_performance += rec.expected_performance * weight;
914            expected_runtime += rec.expected_runtime * weight;
915            uncertainty_estimate += rec.uncertainty_estimate * weight;
916
917            for (param_name, param_value) in &rec.recommended_hyperparameters {
918                if let ParameterValue::Float(val) = param_value {
919                    *aggregated_hyperparameters
920                        .entry(param_name.clone())
921                        .or_insert(0.0) += val * weight;
922                }
923            }
924
925            for (param_name, confidence) in &rec.confidence_scores {
926                *confidence_scores.entry(param_name.clone()).or_insert(0.0) += confidence * weight;
927            }
928        }
929
930        let mut recommended_hyperparameters = HashMap::new();
931        for (param_name, weighted_sum) in aggregated_hyperparameters {
932            recommended_hyperparameters.insert(param_name, ParameterValue::Float(weighted_sum));
933        }
934
935        Ok(MetaLearningRecommendation {
936            recommended_hyperparameters,
937            confidence_scores,
938            expected_performance,
939            expected_runtime,
940            similar_datasets: vec![],
941            recommendation_source: "EnsembleWeightedAverage".to_string(),
942            uncertainty_estimate,
943        })
944    }
945}
946
947impl RandomForestSurrogate {
948    fn new(n_trees: usize, max_depth: Option<usize>) -> Self {
949        Self {
950            n_trees,
951            max_depth,
952            models: Vec::new(),
953        }
954    }
955}
956
957impl SurrogateModelTrait for RandomForestSurrogate {
958    fn fit(
959        &mut self,
960        _features: &Array2<Float>,
961        targets: &Array1<Float>,
962    ) -> Result<(), Box<dyn std::error::Error>> {
963        self.models.clear();
964
965        for _ in 0..self.n_trees {
966            let tree = SimpleTree {
967                feature_idx: None,
968                threshold: None,
969                left: None,
970                right: None,
971                prediction: Some(targets.mean().unwrap_or(0.0)),
972            };
973            self.models.push(tree);
974        }
975
976        Ok(())
977    }
978
979    fn predict(
980        &self,
981        features: &Array2<Float>,
982    ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
983        let n_samples = features.nrows();
984        let mut predictions = Array1::zeros(n_samples);
985
986        for i in 0..n_samples {
987            let mut sum = 0.0;
988            for tree in &self.models {
989                sum += tree.prediction.unwrap_or(0.0);
990            }
991            predictions[i] = sum / self.models.len() as Float;
992        }
993
994        Ok(predictions)
995    }
996
997    fn predict_with_uncertainty(
998        &self,
999        features: &Array2<Float>,
1000    ) -> Result<(Array1<Float>, Array1<Float>), Box<dyn std::error::Error>> {
1001        let predictions = self.predict(features)?;
1002        let uncertainties = Array1::from_elem(predictions.len(), 0.1); // Placeholder uncertainty
1003        Ok((predictions, uncertainties))
1004    }
1005}
1006
1007/// Convenience function for meta-learning based hyperparameter initialization
1008pub fn meta_learning_recommend(
1009    dataset_characteristics: &DatasetCharacteristics,
1010    algorithm_name: &str,
1011    historical_records: Vec<OptimizationRecord>,
1012    config: Option<MetaLearningConfig>,
1013) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
1014    let config = config.unwrap_or_default();
1015    let mut engine = MetaLearningEngine::new(config);
1016    engine.load_historical_records(historical_records);
1017    engine.recommend_hyperparameters(dataset_characteristics, algorithm_name)
1018}
1019
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small binary-classification dataset description with ten
    /// numerical features, shared by the tests below.
    fn create_sample_dataset_characteristics() -> DatasetCharacteristics {
        const N_FEATURES: usize = 10;

        let statistical_measures = StatisticalMeasures {
            mean_values: vec![0.0; N_FEATURES],
            std_values: vec![1.0; N_FEATURES],
            skewness: vec![0.0; N_FEATURES],
            kurtosis: vec![3.0; N_FEATURES],
            correlation_matrix: None,
            mutual_information: None,
        };

        let complexity_measures = ComplexityMeasures {
            fisher_discriminant_ratio: 1.5,
            volume_of_overlap: 0.3,
            feature_efficiency: 0.8,
            collective_feature_efficiency: 0.7,
            entropy: 0.9,
            class_probability_max: 0.6,
        };

        DatasetCharacteristics {
            n_samples: 1000,
            n_features: N_FEATURES,
            n_classes: Some(2),
            class_balance: vec![0.6, 0.4],
            feature_types: vec![FeatureType::Numerical; N_FEATURES],
            statistical_measures,
            complexity_measures,
            domain_specific: HashMap::new(),
        }
    }

    /// A freshly created engine holds no historical records.
    #[test]
    fn test_meta_learning_engine_creation() {
        let engine = MetaLearningEngine::new(MetaLearningConfig::default());
        assert_eq!(engine.historical_records.len(), 0);
    }

    /// Cosine similarity of two identical dataset descriptions lies in [0, 1].
    #[test]
    fn test_similarity_calculation() {
        let engine = MetaLearningEngine::new(MetaLearningConfig::default());

        let first = create_sample_dataset_characteristics();
        let second = create_sample_dataset_characteristics();

        let similarity = engine
            .calculate_similarity(&first, &second, &SimilarityMetric::Cosine)
            .unwrap();
        assert!((0.0..=1.0).contains(&similarity));
    }

    /// Feature extraction yields at least one feature for the sample dataset.
    #[test]
    fn test_feature_extraction() {
        let engine = MetaLearningEngine::new(MetaLearningConfig::default());

        let features = engine
            .extract_features(&create_sample_dataset_characteristics())
            .unwrap();

        assert!(!features.is_empty());
    }

    /// The convenience entry point is expected to return an error when it is
    /// given only a single historical record.
    #[test]
    fn test_meta_learning_recommendation() {
        let dataset_characteristics = create_sample_dataset_characteristics();

        let hyperparameters = HashMap::from([
            ("learning_rate".to_string(), ParameterValue::Float(0.01)),
            ("n_estimators".to_string(), ParameterValue::Integer(100)),
        ]);

        let history = vec![OptimizationRecord {
            dataset_id: "test_dataset".to_string(),
            algorithm_name: "RandomForest".to_string(),
            dataset_characteristics: dataset_characteristics.clone(),
            hyperparameters,
            performance_score: 0.85,
            optimization_time: 120.0,
            convergence_iterations: 50,
            validation_method: "5-fold-cv".to_string(),
            timestamp: 1234567890,
        }];

        let result =
            meta_learning_recommend(&dataset_characteristics, "RandomForest", history, None);

        // Should fail due to insufficient historical data
        assert!(result.is_err());
    }
}