1use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::SeedableRng;
10use serde::{Deserialize, Serialize};
11use sklears_core::types::Float;
12use std::collections::HashMap;
13
/// Strategy the engine uses to turn historical optimization records into a
/// hyperparameter recommendation.
#[derive(Debug, Clone)]
pub enum MetaLearningStrategy {
    /// Aggregate hyperparameters from the most similar historical datasets.
    SimilarityBased {
        /// Metric for comparing dataset characteristic vectors.
        similarity_metric: SimilarityMetric,

        /// Number of nearest neighbors to aggregate.
        k_neighbors: usize,

        /// Weight each neighbor by its similarity instead of uniformly.
        weight_by_distance: bool,
    },
    /// Query a surrogate model trained on historical runs.
    ModelBased {
        /// Surrogate model family to use.
        surrogate_model: SurrogateModel,

        /// Retrain cadence, in number of new records.
        update_frequency: usize,
    },
    /// Gradient-style adaptation of hyperparameters.
    /// NOTE(review): the current implementation only reuses the best
    /// historical run; these fields are not consulted yet.
    GradientBased {
        meta_learning_rate: Float,
        adaptation_steps: usize,
        inner_learning_rate: Float,
    },
    /// Bayesian pooling of historical hyperparameter values.
    /// NOTE(review): `prior_strength` and `hierarchical_levels` are not
    /// consulted by the current implementation.
    BayesianMeta {
        prior_strength: Float,
        hierarchical_levels: usize,
    },
    /// Transfer the configuration of the single most similar source dataset.
    /// NOTE(review): both fields are currently unused by
    /// `transfer_learning_recommendation`.
    TransferLearning {
        transfer_method: TransferMethod,
        domain_adaptation: bool,
    },
    /// Run several strategies and combine their recommendations.
    EnsembleMeta {
        strategies: Vec<MetaLearningStrategy>,
        combination_method: CombinationMethod,
    },
}
53
/// Metric for comparing two dataset characteristic vectors.
#[derive(Debug, Clone)]
pub enum SimilarityMetric {
    /// Cosine similarity of the feature vectors.
    Cosine,
    /// Similarity derived from Euclidean distance (1 / (1 + d)).
    Euclidean,
    /// Similarity derived from Manhattan distance (1 / (1 + d)).
    Manhattan,
    /// Correlation-based similarity.
    /// NOTE(review): currently falls back to cosine in `calculate_similarity`.
    Correlation,
    /// Jensen-Shannon-based similarity.
    /// NOTE(review): currently falls back to cosine in `calculate_similarity`.
    JensenShannon,
    /// Learned similarity function.
    /// NOTE(review): currently falls back to cosine in `calculate_similarity`.
    Learned,
}
70
/// Surrogate model family for `MetaLearningStrategy::ModelBased`.
#[derive(Debug, Clone)]
pub enum SurrogateModel {
    /// Random-forest regressor.
    RandomForest {
        /// Number of trees in the ensemble.
        n_trees: usize,

        /// Maximum depth per tree (None = unbounded).
        max_depth: Option<usize>,
    },
    /// Gaussian-process regressor with a named kernel.
    GaussianProcess { kernel_type: String },
    /// Feed-forward network with the given hidden-layer widths.
    NeuralNetwork { hidden_layers: Vec<usize> },
    /// Linear regression with the given regularization strength.
    LinearRegression { regularization: Float },
}
87
/// How knowledge is transferred from a source dataset in
/// `MetaLearningStrategy::TransferLearning`.
/// NOTE(review): the variants are not yet distinguished by
/// `transfer_learning_recommendation`.
#[derive(Debug, Clone)]
pub enum TransferMethod {
    /// Reuse the source configuration as-is.
    DirectTransfer,
    /// Transfer feature-level knowledge.
    FeatureTransfer,
    /// Transfer model-level knowledge.
    ModelTransfer,
    /// Transfer instance-level knowledge.
    InstanceTransfer,
}
100
/// How ensemble sub-recommendations are merged.
#[derive(Debug, Clone)]
pub enum CombinationMethod {
    /// Unweighted mean of all recommendations.
    Average,
    /// Mean weighted by each recommendation's certainty.
    WeightedAverage,
    /// Stacked combiner.
    /// NOTE(review): currently falls back to `Average` in
    /// `ensemble_meta_recommendation`.
    Stacking,
    /// Bayesian model averaging.
    /// NOTE(review): currently falls back to `Average` in
    /// `ensemble_meta_recommendation`.
    BayesianAveraging,
}
113
/// Summary of a dataset used as the "task descriptor" for meta-learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetCharacteristics {
    /// Number of rows/examples.
    pub n_samples: usize,
    /// Number of input features.
    pub n_features: usize,
    /// Number of target classes (None when not applicable).
    pub n_classes: Option<usize>,
    /// Per-class proportion of samples.
    pub class_balance: Vec<Float>,
    /// Per-feature data type.
    pub feature_types: Vec<FeatureType>,
    /// Per-feature statistical summaries.
    pub statistical_measures: StatisticalMeasures,
    /// Scalar dataset-complexity indicators.
    pub complexity_measures: ComplexityMeasures,
    /// Arbitrary extra domain-specific numeric descriptors.
    pub domain_specific: HashMap<String, Float>,
}
126
/// Data type of a single feature column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FeatureType {
    /// Continuous/numeric values.
    Numerical,
    /// Unordered categories.
    Categorical,
    /// Ordered categories.
    Ordinal,
    /// Free-form text.
    Text,
    /// Image data.
    Image,
    /// Sequential/time-indexed data.
    TimeSeries,
}
143
/// Per-feature statistical summaries of a dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatisticalMeasures {
    /// Per-feature means (consumed by `extract_features`).
    pub mean_values: Vec<Float>,
    /// Per-feature standard deviations.
    pub std_values: Vec<Float>,
    /// Per-feature skewness.
    pub skewness: Vec<Float>,
    /// Per-feature kurtosis.
    pub kurtosis: Vec<Float>,
    /// Optional feature-feature correlation matrix.
    pub correlation_matrix: Option<Array2<Float>>,
    /// Optional per-feature mutual-information scores.
    pub mutual_information: Option<Vec<Float>>,
}
154
/// Scalar dataset-complexity indicators (meta-features).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplexityMeasures {
    /// Fisher's discriminant ratio (class separability).
    pub fisher_discriminant_ratio: Float,
    /// Volume of the class-overlap region.
    pub volume_of_overlap: Float,
    /// Efficiency of the best individual feature.
    pub feature_efficiency: Float,
    /// Efficiency of all features combined.
    pub collective_feature_efficiency: Float,
    /// Entropy measure of the dataset.
    pub entropy: Float,
    /// Largest class prior probability.
    pub class_probability_max: Float,
}
165
/// One completed hyperparameter-optimization run stored as meta-learning
/// history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationRecord {
    /// Identifier of the dataset the run was performed on.
    pub dataset_id: String,
    /// Name of the optimized algorithm.
    pub algorithm_name: String,
    /// Characteristics of the dataset at optimization time.
    pub dataset_characteristics: DatasetCharacteristics,
    /// The hyperparameter configuration that was evaluated.
    pub hyperparameters: HashMap<String, ParameterValue>,
    /// Final performance achieved by the run.
    pub performance_score: Float,
    /// Wall-clock optimization time.
    /// NOTE(review): the unit (seconds?) is not established anywhere in this file.
    pub optimization_time: Float,
    /// Iterations until convergence.
    pub convergence_iterations: usize,
    /// Description of the validation scheme (e.g. "5-fold-cv").
    pub validation_method: String,
    /// Unix-style timestamp of the run.
    pub timestamp: u64,
}
179
/// Typed value of a single hyperparameter.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ParameterValue {
    /// Continuous value.
    Float(Float),
    /// Integral value.
    Integer(i64),
    /// On/off flag.
    Boolean(bool),
    /// Categorical/string value.
    String(String),
    /// Vector-valued parameter.
    Array(Vec<Float>),
}
194
/// Output of the meta-learning engine for one (dataset, algorithm) query.
#[derive(Debug, Clone)]
pub struct MetaLearningRecommendation {
    /// Suggested starting hyperparameters.
    pub recommended_hyperparameters: HashMap<String, ParameterValue>,
    /// Per-parameter confidence in the suggestion (higher is more confident).
    pub confidence_scores: HashMap<String, Float>,
    /// Predicted performance with the suggested configuration.
    pub expected_performance: Float,
    /// Predicted optimization runtime.
    pub expected_runtime: Float,
    /// IDs of the historical datasets the suggestion was derived from.
    pub similar_datasets: Vec<String>,
    /// Which strategy produced this recommendation (e.g. "SimilarityBased").
    pub recommendation_source: String,
    /// Overall uncertainty of the recommendation (higher is less certain).
    pub uncertainty_estimate: Float,
}
206
/// Configuration of the meta-learning engine.
#[derive(Debug, Clone)]
pub struct MetaLearningConfig {
    /// Active recommendation strategy.
    pub strategy: MetaLearningStrategy,
    /// Minimum history size before recommendations are attempted.
    pub min_historical_records: usize,
    /// Maximum allowed dissimilarity for a neighbor.
    /// NOTE(review): not consulted by any method in this file yet.
    pub max_similarity_distance: Float,
    /// Minimum confidence for accepting a recommendation.
    /// NOTE(review): not consulted by any method in this file yet.
    pub confidence_threshold: Float,
    /// Retrain surrogate models every this many newly added records.
    pub update_interval: usize,
    /// Capacity of the similarity cache.
    /// NOTE(review): not consulted by any method in this file yet.
    pub cache_size: usize,
    /// Optional RNG seed for reproducibility.
    pub random_state: Option<u64>,
}
218
/// Meta-learning engine: accumulates historical optimization records and
/// produces hyperparameter recommendations for new datasets.
#[derive(Debug)]
pub struct MetaLearningEngine {
    /// Engine configuration, including the active strategy.
    config: MetaLearningConfig,
    /// All optimization runs observed so far.
    historical_records: Vec<OptimizationRecord>,
    /// Cached per-dataset similarity lists.
    /// NOTE(review): never written or read in this file — confirm it is used
    /// elsewhere or remove it.
    dataset_similarity_cache: HashMap<String, Vec<(String, Float)>>,
    /// Trained surrogate models keyed as "<algorithm>_hyperparams".
    surrogate_models: HashMap<String, Box<dyn SurrogateModelTrait>>,
    /// RNG seeded from `config.random_state` when provided.
    /// NOTE(review): not currently used by any method in this file.
    rng: StdRng,
}
228
/// Regression surrogate used to predict optimization performance from
/// dataset characteristic vectors.
trait SurrogateModelTrait: std::fmt::Debug {
    /// Fit the model on an (n_samples x n_features) matrix with one target
    /// score per row.
    fn fit(
        &mut self,
        features: &Array2<Float>,
        targets: &Array1<Float>,
    ) -> Result<(), Box<dyn std::error::Error>>;
    /// Predict one score per row of `features`.
    fn predict(
        &self,
        features: &Array2<Float>,
    ) -> Result<Array1<Float>, Box<dyn std::error::Error>>;
    /// Predict scores together with a per-row uncertainty estimate.
    fn predict_with_uncertainty(
        &self,
        features: &Array2<Float>,
    ) -> Result<(Array1<Float>, Array1<Float>), Box<dyn std::error::Error>>;
}
245
/// Stub random-forest surrogate built from constant-leaf `SimpleTree`s.
#[derive(Debug)]
struct RandomForestSurrogate {
    /// Number of trees created on `fit`.
    n_trees: usize,
    /// Maximum tree depth.
    /// NOTE(review): not consulted by the current stub `fit`.
    max_depth: Option<usize>,
    /// Fitted trees (empty until `fit` is called).
    models: Vec<SimpleTree>,
}
253
/// Minimal decision-tree node. A leaf carries `prediction`; an internal node
/// would carry `feature_idx`/`threshold` and children (the current stub `fit`
/// only ever builds leaves).
#[derive(Debug, Clone)]
struct SimpleTree {
    /// Index of the feature to split on (None for leaves).
    feature_idx: Option<usize>,
    /// Split threshold (None for leaves).
    threshold: Option<Float>,
    /// Left subtree.
    left: Option<Box<SimpleTree>>,
    /// Right subtree.
    right: Option<Box<SimpleTree>>,
    /// Leaf prediction value.
    prediction: Option<Float>,
}
263
impl Default for MetaLearningConfig {
    /// Defaults: cosine-similarity 5-NN with distance weighting, at least 10
    /// historical records before recommending, and an unseeded RNG.
    fn default() -> Self {
        Self {
            strategy: MetaLearningStrategy::SimilarityBased {
                similarity_metric: SimilarityMetric::Cosine,
                k_neighbors: 5,
                weight_by_distance: true,
            },
            min_historical_records: 10,
            max_similarity_distance: 0.8,
            confidence_threshold: 0.6,
            update_interval: 100,
            cache_size: 1000,
            random_state: None,
        }
    }
}
281
282impl MetaLearningEngine {
283 pub fn new(config: MetaLearningConfig) -> Self {
285 let rng = match config.random_state {
286 Some(seed) => StdRng::seed_from_u64(seed),
287 None => {
288 use scirs2_core::random::thread_rng;
289 StdRng::from_rng(&mut thread_rng())
290 }
291 };
292
293 Self {
294 config,
295 historical_records: Vec::new(),
296 dataset_similarity_cache: HashMap::new(),
297 surrogate_models: HashMap::new(),
298 rng,
299 }
300 }
301
302 pub fn load_historical_records(&mut self, records: Vec<OptimizationRecord>) {
304 self.historical_records.extend(records);
305 self.update_surrogate_models().unwrap_or_else(|e| {
306 eprintln!("Warning: Failed to update surrogate models: {}", e);
307 });
308 }
309
310 pub fn add_record(&mut self, record: OptimizationRecord) {
312 self.historical_records.push(record);
313
314 if self.historical_records.len() % self.config.update_interval == 0 {
316 self.update_surrogate_models().unwrap_or_else(|e| {
317 eprintln!("Warning: Failed to update surrogate models: {}", e);
318 });
319 }
320 }
321
322 pub fn recommend_hyperparameters(
324 &mut self,
325 dataset_characteristics: &DatasetCharacteristics,
326 algorithm_name: &str,
327 ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
328 if self.historical_records.len() < self.config.min_historical_records {
329 return Err("Insufficient historical data for meta-learning".into());
330 }
331
332 match &self.config.strategy {
333 MetaLearningStrategy::SimilarityBased { .. } => {
334 self.similarity_based_recommendation(dataset_characteristics, algorithm_name)
335 }
336 MetaLearningStrategy::ModelBased { .. } => {
337 self.model_based_recommendation(dataset_characteristics, algorithm_name)
338 }
339 MetaLearningStrategy::GradientBased { .. } => {
340 self.gradient_based_recommendation(dataset_characteristics, algorithm_name)
341 }
342 MetaLearningStrategy::BayesianMeta { .. } => {
343 self.bayesian_meta_recommendation(dataset_characteristics, algorithm_name)
344 }
345 MetaLearningStrategy::TransferLearning { .. } => {
346 self.transfer_learning_recommendation(dataset_characteristics, algorithm_name)
347 }
348 MetaLearningStrategy::EnsembleMeta { .. } => {
349 self.ensemble_meta_recommendation(dataset_characteristics, algorithm_name)
350 }
351 }
352 }
353
354 fn similarity_based_recommendation(
356 &mut self,
357 dataset_characteristics: &DatasetCharacteristics,
358 algorithm_name: &str,
359 ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
360 let (similarity_metric, k_neighbors, weight_by_distance) = match &self.config.strategy {
361 MetaLearningStrategy::SimilarityBased {
362 similarity_metric,
363 k_neighbors,
364 weight_by_distance,
365 } => (similarity_metric, *k_neighbors, *weight_by_distance),
366 _ => unreachable!(),
367 };
368
369 let mut similarities = Vec::new();
371 for record in &self.historical_records {
372 if record.algorithm_name == algorithm_name {
373 let similarity = self.calculate_similarity(
374 dataset_characteristics,
375 &record.dataset_characteristics,
376 similarity_metric,
377 )?;
378 similarities.push((record, similarity));
379 }
380 }
381
382 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
384 similarities.truncate(k_neighbors);
385
386 if similarities.is_empty() {
387 return Err("No similar datasets found".into());
388 }
389
390 let mut aggregated_hyperparameters = HashMap::new();
392 let mut confidence_scores = HashMap::new();
393 let mut expected_performance = 0.0;
394 let mut expected_runtime = 0.0;
395 let mut total_weight = 0.0;
396
397 for (record, similarity) in &similarities {
398 let weight = if weight_by_distance { *similarity } else { 1.0 };
399 total_weight += weight;
400
401 expected_performance += record.performance_score * weight;
402 expected_runtime += record.optimization_time * weight;
403
404 for (param_name, param_value) in &record.hyperparameters {
405 match param_value {
406 ParameterValue::Float(val) => {
407 let entry = aggregated_hyperparameters
408 .entry(param_name.clone())
409 .or_insert_with(|| (0.0, 0.0)); entry.0 += val * weight;
411 entry.1 += weight;
412 }
413 ParameterValue::Integer(val) => {
414 let entry = aggregated_hyperparameters
415 .entry(param_name.clone())
416 .or_insert_with(|| (0.0, 0.0));
417 entry.0 += *val as Float * weight;
418 entry.1 += weight;
419 }
420 _ => {
421 }
424 }
425
426 confidence_scores.insert(param_name.clone(), *similarity);
427 }
428 }
429
430 let mut recommended_hyperparameters = HashMap::new();
432 for (param_name, (sum, weight_sum)) in aggregated_hyperparameters {
433 let avg_value = sum / weight_sum;
434 recommended_hyperparameters.insert(param_name, ParameterValue::Float(avg_value));
435 }
436
437 expected_performance /= total_weight;
438 expected_runtime /= total_weight;
439
440 let similar_datasets = similarities
441 .iter()
442 .map(|(record, _)| record.dataset_id.clone())
443 .collect();
444
445 let uncertainty_estimate = 1.0
446 - similarities.iter().map(|(_, sim)| sim).sum::<Float>() / similarities.len() as Float;
447
448 Ok(MetaLearningRecommendation {
449 recommended_hyperparameters,
450 confidence_scores,
451 expected_performance,
452 expected_runtime,
453 similar_datasets,
454 recommendation_source: "SimilarityBased".to_string(),
455 uncertainty_estimate,
456 })
457 }
458
    /// Recommend hyperparameters by querying the surrogate model trained for
    /// `algorithm_name` (see `update_surrogate_models`).
    ///
    /// NOTE(review): this is currently a placeholder mapping — the surrogate
    /// predicts a single score that is reused both as the `learning_rate`
    /// value and as the expected performance, and the expected runtime is a
    /// hard-coded constant.
    ///
    /// # Errors
    /// Fails when no surrogate model exists for the algorithm, or when
    /// feature extraction / prediction fails.
    fn model_based_recommendation(
        &mut self,
        dataset_characteristics: &DatasetCharacteristics,
        algorithm_name: &str,
    ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
        // Key must match the one built in `update_surrogate_models`.
        let model_key = format!("{}_{}", algorithm_name, "hyperparams");

        if let Some(model) = self.surrogate_models.get(&model_key) {
            // The surrogate expects a 2-D matrix; wrap the single
            // characteristics vector as one row.
            let features = self.extract_features(dataset_characteristics)?;
            let features_2d = Array2::from_shape_vec((1, features.len()), features.to_vec())?;

            let (predictions, uncertainties) = model.predict_with_uncertainty(&features_2d)?;

            let mut recommended_hyperparameters = HashMap::new();
            let mut confidence_scores = HashMap::new();

            recommended_hyperparameters.insert(
                "learning_rate".to_string(),
                ParameterValue::Float(predictions[0]),
            );
            confidence_scores.insert("learning_rate".to_string(), 1.0 - uncertainties[0]);

            Ok(MetaLearningRecommendation {
                recommended_hyperparameters,
                confidence_scores,
                expected_performance: predictions[0],
                expected_runtime: 100.0, // placeholder runtime estimate
                similar_datasets: vec![],
                recommendation_source: "ModelBased".to_string(),
                uncertainty_estimate: uncertainties[0],
            })
        } else {
            Err("Surrogate model not available".into())
        }
    }
497
498 fn gradient_based_recommendation(
500 &mut self,
501 _dataset_characteristics: &DatasetCharacteristics,
502 algorithm_name: &str,
503 ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
504 let similar_records: Vec<&OptimizationRecord> = self
508 .historical_records
509 .iter()
510 .filter(|r| r.algorithm_name == algorithm_name)
511 .collect();
512
513 if similar_records.is_empty() {
514 return Err("No historical records for algorithm".into());
515 }
516
517 let mut adapted_hyperparameters = HashMap::new();
519 let mut confidence_scores = HashMap::new();
520
521 let best_record = similar_records
523 .iter()
524 .max_by(|a, b| {
525 a.performance_score
526 .partial_cmp(&b.performance_score)
527 .unwrap()
528 })
529 .unwrap();
530
531 for (param_name, param_value) in &best_record.hyperparameters {
532 adapted_hyperparameters.insert(param_name.clone(), param_value.clone());
533 confidence_scores.insert(param_name.clone(), 0.8); }
535
536 Ok(MetaLearningRecommendation {
537 recommended_hyperparameters: adapted_hyperparameters,
538 confidence_scores,
539 expected_performance: best_record.performance_score,
540 expected_runtime: best_record.optimization_time,
541 similar_datasets: vec![best_record.dataset_id.clone()],
542 recommendation_source: "GradientBased".to_string(),
543 uncertainty_estimate: 0.2,
544 })
545 }
546
547 fn bayesian_meta_recommendation(
549 &mut self,
550 _dataset_characteristics: &DatasetCharacteristics,
551 algorithm_name: &str,
552 ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
553 let relevant_records: Vec<&OptimizationRecord> = self
555 .historical_records
556 .iter()
557 .filter(|r| r.algorithm_name == algorithm_name)
558 .collect();
559
560 if relevant_records.is_empty() {
561 return Err("No historical records for algorithm".into());
562 }
563
564 let mut hyperparameter_distributions = HashMap::new();
566 let mut confidence_scores = HashMap::new();
567
568 for record in &relevant_records {
569 for (param_name, param_value) in &record.hyperparameters {
570 if let ParameterValue::Float(val) = param_value {
571 let entry = hyperparameter_distributions
572 .entry(param_name.clone())
573 .or_insert_with(Vec::new);
574 entry.push(*val);
575 }
576 }
577 }
578
579 let mut recommended_hyperparameters = HashMap::new();
580 for (param_name, values) in hyperparameter_distributions {
581 let mean = values.iter().sum::<Float>() / values.len() as Float;
582 let variance =
583 values.iter().map(|v| (v - mean).powi(2)).sum::<Float>() / values.len() as Float;
584
585 recommended_hyperparameters.insert(param_name.clone(), ParameterValue::Float(mean));
586 confidence_scores.insert(param_name, 1.0 / (1.0 + variance)); }
588
589 let avg_performance = relevant_records
590 .iter()
591 .map(|r| r.performance_score)
592 .sum::<Float>()
593 / relevant_records.len() as Float;
594
595 Ok(MetaLearningRecommendation {
596 recommended_hyperparameters,
597 confidence_scores,
598 expected_performance: avg_performance,
599 expected_runtime: 100.0, similar_datasets: relevant_records
601 .iter()
602 .map(|r| r.dataset_id.clone())
603 .collect(),
604 recommendation_source: "BayesianMeta".to_string(),
605 uncertainty_estimate: 0.3,
606 })
607 }
608
609 fn transfer_learning_recommendation(
611 &mut self,
612 dataset_characteristics: &DatasetCharacteristics,
613 algorithm_name: &str,
614 ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
615 let mut best_similarity = 0.0;
617 let mut best_record = None;
618
619 for record in &self.historical_records {
620 if record.algorithm_name == algorithm_name {
621 let similarity = self.calculate_similarity(
622 dataset_characteristics,
623 &record.dataset_characteristics,
624 &SimilarityMetric::Cosine,
625 )?;
626
627 if similarity > best_similarity {
628 best_similarity = similarity;
629 best_record = Some(record);
630 }
631 }
632 }
633
634 if let Some(record) = best_record {
635 let mut confidence_scores = HashMap::new();
636 for param_name in record.hyperparameters.keys() {
637 confidence_scores.insert(param_name.clone(), best_similarity);
638 }
639
640 Ok(MetaLearningRecommendation {
641 recommended_hyperparameters: record.hyperparameters.clone(),
642 confidence_scores,
643 expected_performance: record.performance_score * best_similarity,
644 expected_runtime: record.optimization_time,
645 similar_datasets: vec![record.dataset_id.clone()],
646 recommendation_source: "TransferLearning".to_string(),
647 uncertainty_estimate: 1.0 - best_similarity,
648 })
649 } else {
650 Err("No suitable source domain found for transfer learning".into())
651 }
652 }
653
654 fn ensemble_meta_recommendation(
656 &mut self,
657 dataset_characteristics: &DatasetCharacteristics,
658 algorithm_name: &str,
659 ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
660 let (strategies, combination_method) = match &self.config.strategy {
661 MetaLearningStrategy::EnsembleMeta {
662 strategies,
663 combination_method,
664 } => (strategies, combination_method),
665 _ => unreachable!(),
666 };
667
668 let mut recommendations = Vec::new();
669
670 for strategy in strategies {
672 let mut temp_config = self.config.clone();
673 temp_config.strategy = strategy.clone();
674 let mut temp_engine = MetaLearningEngine::new(temp_config);
675 temp_engine.historical_records = self.historical_records.clone();
676
677 if let Ok(rec) =
678 temp_engine.recommend_hyperparameters(dataset_characteristics, algorithm_name)
679 {
680 recommendations.push(rec);
681 }
682 }
683
684 if recommendations.is_empty() {
685 return Err("No recommendations from ensemble strategies".into());
686 }
687
688 match combination_method {
690 CombinationMethod::Average => self.average_recommendations(recommendations),
691 CombinationMethod::WeightedAverage => {
692 self.weighted_average_recommendations(recommendations)
693 }
694 _ => {
695 self.average_recommendations(recommendations)
697 }
698 }
699 }
700
701 fn calculate_similarity(
703 &self,
704 dataset1: &DatasetCharacteristics,
705 dataset2: &DatasetCharacteristics,
706 metric: &SimilarityMetric,
707 ) -> Result<Float, Box<dyn std::error::Error>> {
708 let features1 = self.extract_features(dataset1)?;
709 let features2 = self.extract_features(dataset2)?;
710
711 match metric {
712 SimilarityMetric::Cosine => {
713 let dot_product = features1
714 .iter()
715 .zip(features2.iter())
716 .map(|(a, b)| a * b)
717 .sum::<Float>();
718 let norm1 = (features1.iter().map(|x| x * x).sum::<Float>()).sqrt();
719 let norm2 = (features2.iter().map(|x| x * x).sum::<Float>()).sqrt();
720 Ok(dot_product / (norm1 * norm2))
721 }
722 SimilarityMetric::Euclidean => {
723 let distance = features1
724 .iter()
725 .zip(features2.iter())
726 .map(|(a, b)| (a - b).powi(2))
727 .sum::<Float>()
728 .sqrt();
729 Ok(1.0 / (1.0 + distance))
730 }
731 SimilarityMetric::Manhattan => {
732 let distance = features1
733 .iter()
734 .zip(features2.iter())
735 .map(|(a, b)| (a - b).abs())
736 .sum::<Float>();
737 Ok(1.0 / (1.0 + distance))
738 }
739 _ => {
740 let dot_product = features1
742 .iter()
743 .zip(features2.iter())
744 .map(|(a, b)| a * b)
745 .sum::<Float>();
746 let norm1 = (features1.iter().map(|x| x * x).sum::<Float>()).sqrt();
747 let norm2 = (features2.iter().map(|x| x * x).sum::<Float>()).sqrt();
748 Ok(dot_product / (norm1 * norm2))
749 }
750 }
751 }
752
753 fn extract_features(
755 &self,
756 characteristics: &DatasetCharacteristics,
757 ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
758 let mut features = Vec::new();
759
760 features.push(characteristics.n_samples as Float);
762 features.push(characteristics.n_features as Float);
763 features.push(characteristics.n_classes.unwrap_or(0) as Float);
764
765 if !characteristics.statistical_measures.mean_values.is_empty() {
767 features.extend(&characteristics.statistical_measures.mean_values);
768 }
769
770 features.push(
772 characteristics
773 .complexity_measures
774 .fisher_discriminant_ratio,
775 );
776 features.push(characteristics.complexity_measures.volume_of_overlap);
777 features.push(characteristics.complexity_measures.feature_efficiency);
778 features.push(characteristics.complexity_measures.entropy);
779
780 Ok(Array1::from_vec(features))
781 }
782
    /// (Re)train one surrogate model per algorithm that has at least five
    /// historical records, keyed as "<algorithm>_hyperparams".
    ///
    /// # Errors
    /// Propagates feature-extraction, shape-construction, or fitting errors.
    fn update_surrogate_models(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        // Group record references by algorithm name.
        let mut algorithm_groups: HashMap<String, Vec<&OptimizationRecord>> = HashMap::new();

        for record in &self.historical_records {
            algorithm_groups
                .entry(record.algorithm_name.clone())
                .or_default()
                .push(record);
        }

        for (algorithm_name, records) in algorithm_groups {
            // Fewer than 5 records makes the surrogate useless; skip.
            if records.len() >= 5 {
                // Key must match the lookup in `model_based_recommendation`.
                let model_key = format!("{}_{}", algorithm_name, "hyperparams");

                // Build the (n_records x n_features) design matrix row by
                // row, with performance scores as regression targets.
                let mut features_vec = Vec::new();
                let mut targets = Vec::new();

                for record in &records {
                    let features = self.extract_features(&record.dataset_characteristics)?;
                    features_vec.extend(features.to_vec());
                    targets.push(record.performance_score);
                }

                let n_features = self
                    .extract_features(&records[0].dataset_characteristics)?
                    .len();
                let features_2d =
                    Array2::from_shape_vec((records.len(), n_features), features_vec)?;
                let targets_1d = Array1::from_vec(targets);

                let mut surrogate = Box::new(RandomForestSurrogate::new(10, Some(5)));
                surrogate.fit(&features_2d, &targets_1d)?;

                // Box<RandomForestSurrogate> coerces to Box<dyn SurrogateModelTrait>.
                self.surrogate_models.insert(model_key, surrogate);
            }
        }

        Ok(())
    }
828
829 fn average_recommendations(
831 &self,
832 recommendations: Vec<MetaLearningRecommendation>,
833 ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
834 if recommendations.is_empty() {
835 return Err("No recommendations to average".into());
836 }
837
838 let mut aggregated_hyperparameters = HashMap::new();
839 let mut confidence_scores = HashMap::new();
840 let mut expected_performance = 0.0;
841 let mut expected_runtime = 0.0;
842 let mut uncertainty_estimate = 0.0;
843
844 let n_recommendations = recommendations.len() as Float;
845
846 for rec in &recommendations {
847 expected_performance += rec.expected_performance;
848 expected_runtime += rec.expected_runtime;
849 uncertainty_estimate += rec.uncertainty_estimate;
850
851 for (param_name, param_value) in &rec.recommended_hyperparameters {
852 if let ParameterValue::Float(val) = param_value {
853 *aggregated_hyperparameters
854 .entry(param_name.clone())
855 .or_insert(0.0) += val;
856 }
857 }
858
859 for (param_name, confidence) in &rec.confidence_scores {
860 *confidence_scores.entry(param_name.clone()).or_insert(0.0) += confidence;
861 }
862 }
863
864 let mut recommended_hyperparameters = HashMap::new();
866 for (param_name, sum) in aggregated_hyperparameters {
867 recommended_hyperparameters.insert(
868 param_name.clone(),
869 ParameterValue::Float(sum / n_recommendations),
870 );
871 if let Some(conf_sum) = confidence_scores.get_mut(¶m_name) {
872 *conf_sum /= n_recommendations;
873 }
874 }
875
876 Ok(MetaLearningRecommendation {
877 recommended_hyperparameters,
878 confidence_scores,
879 expected_performance: expected_performance / n_recommendations,
880 expected_runtime: expected_runtime / n_recommendations,
881 similar_datasets: vec![], recommendation_source: "EnsembleAverage".to_string(),
883 uncertainty_estimate: uncertainty_estimate / n_recommendations,
884 })
885 }
886
887 fn weighted_average_recommendations(
889 &self,
890 recommendations: Vec<MetaLearningRecommendation>,
891 ) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
892 if recommendations.is_empty() {
893 return Err("No recommendations to average".into());
894 }
895
896 let weights: Vec<Float> = recommendations
898 .iter()
899 .map(|r| 1.0 - r.uncertainty_estimate)
900 .collect();
901
902 let total_weight: Float = weights.iter().sum();
903
904 let mut aggregated_hyperparameters = HashMap::new();
905 let mut confidence_scores = HashMap::new();
906 let mut expected_performance = 0.0;
907 let mut expected_runtime = 0.0;
908 let mut uncertainty_estimate = 0.0;
909
910 for (i, rec) in recommendations.iter().enumerate() {
911 let weight = weights[i] / total_weight;
912
913 expected_performance += rec.expected_performance * weight;
914 expected_runtime += rec.expected_runtime * weight;
915 uncertainty_estimate += rec.uncertainty_estimate * weight;
916
917 for (param_name, param_value) in &rec.recommended_hyperparameters {
918 if let ParameterValue::Float(val) = param_value {
919 *aggregated_hyperparameters
920 .entry(param_name.clone())
921 .or_insert(0.0) += val * weight;
922 }
923 }
924
925 for (param_name, confidence) in &rec.confidence_scores {
926 *confidence_scores.entry(param_name.clone()).or_insert(0.0) += confidence * weight;
927 }
928 }
929
930 let mut recommended_hyperparameters = HashMap::new();
931 for (param_name, weighted_sum) in aggregated_hyperparameters {
932 recommended_hyperparameters.insert(param_name, ParameterValue::Float(weighted_sum));
933 }
934
935 Ok(MetaLearningRecommendation {
936 recommended_hyperparameters,
937 confidence_scores,
938 expected_performance,
939 expected_runtime,
940 similar_datasets: vec![],
941 recommendation_source: "EnsembleWeightedAverage".to_string(),
942 uncertainty_estimate,
943 })
944 }
945}
946
947impl RandomForestSurrogate {
948 fn new(n_trees: usize, max_depth: Option<usize>) -> Self {
949 Self {
950 n_trees,
951 max_depth,
952 models: Vec::new(),
953 }
954 }
955}
956
957impl SurrogateModelTrait for RandomForestSurrogate {
958 fn fit(
959 &mut self,
960 _features: &Array2<Float>,
961 targets: &Array1<Float>,
962 ) -> Result<(), Box<dyn std::error::Error>> {
963 self.models.clear();
964
965 for _ in 0..self.n_trees {
966 let tree = SimpleTree {
967 feature_idx: None,
968 threshold: None,
969 left: None,
970 right: None,
971 prediction: Some(targets.mean().unwrap_or(0.0)),
972 };
973 self.models.push(tree);
974 }
975
976 Ok(())
977 }
978
979 fn predict(
980 &self,
981 features: &Array2<Float>,
982 ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
983 let n_samples = features.nrows();
984 let mut predictions = Array1::zeros(n_samples);
985
986 for i in 0..n_samples {
987 let mut sum = 0.0;
988 for tree in &self.models {
989 sum += tree.prediction.unwrap_or(0.0);
990 }
991 predictions[i] = sum / self.models.len() as Float;
992 }
993
994 Ok(predictions)
995 }
996
997 fn predict_with_uncertainty(
998 &self,
999 features: &Array2<Float>,
1000 ) -> Result<(Array1<Float>, Array1<Float>), Box<dyn std::error::Error>> {
1001 let predictions = self.predict(features)?;
1002 let uncertainties = Array1::from_elem(predictions.len(), 0.1); Ok((predictions, uncertainties))
1004 }
1005}
1006
1007pub fn meta_learning_recommend(
1009 dataset_characteristics: &DatasetCharacteristics,
1010 algorithm_name: &str,
1011 historical_records: Vec<OptimizationRecord>,
1012 config: Option<MetaLearningConfig>,
1013) -> Result<MetaLearningRecommendation, Box<dyn std::error::Error>> {
1014 let config = config.unwrap_or_default();
1015 let mut engine = MetaLearningEngine::new(config);
1016 engine.load_historical_records(historical_records);
1017 engine.recommend_hyperparameters(dataset_characteristics, algorithm_name)
1018}
1019
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a small, fully-populated `DatasetCharacteristics` fixture.
    fn create_sample_dataset_characteristics() -> DatasetCharacteristics {
        DatasetCharacteristics {
            n_samples: 1000,
            n_features: 10,
            n_classes: Some(2),
            class_balance: vec![0.6, 0.4],
            feature_types: vec![FeatureType::Numerical; 10],
            statistical_measures: StatisticalMeasures {
                mean_values: vec![0.0; 10],
                std_values: vec![1.0; 10],
                skewness: vec![0.0; 10],
                kurtosis: vec![3.0; 10],
                correlation_matrix: None,
                mutual_information: None,
            },
            complexity_measures: ComplexityMeasures {
                fisher_discriminant_ratio: 1.5,
                volume_of_overlap: 0.3,
                feature_efficiency: 0.8,
                collective_feature_efficiency: 0.7,
                entropy: 0.9,
                class_probability_max: 0.6,
            },
            domain_specific: HashMap::new(),
        }
    }

    /// A freshly constructed engine starts with an empty history.
    #[test]
    fn test_meta_learning_engine_creation() {
        let config = MetaLearningConfig::default();
        let engine = MetaLearningEngine::new(config);
        assert_eq!(engine.historical_records.len(), 0);
    }

    /// Cosine similarity of a dataset with itself lies in [0, 1].
    #[test]
    fn test_similarity_calculation() {
        let config = MetaLearningConfig::default();
        let engine = MetaLearningEngine::new(config);

        let dataset1 = create_sample_dataset_characteristics();
        let dataset2 = create_sample_dataset_characteristics();

        let similarity = engine
            .calculate_similarity(&dataset1, &dataset2, &SimilarityMetric::Cosine)
            .unwrap();
        assert!(similarity >= 0.0 && similarity <= 1.0);
    }

    /// Feature extraction yields a non-empty vector.
    #[test]
    fn test_feature_extraction() {
        let config = MetaLearningConfig::default();
        let engine = MetaLearningEngine::new(config);

        let dataset = create_sample_dataset_characteristics();
        let features = engine.extract_features(&dataset).unwrap();

        assert!(features.len() > 0);
    }

    /// With only one historical record — below the default
    /// `min_historical_records` of 10 — the recommendation must fail.
    #[test]
    fn test_meta_learning_recommendation() {
        let dataset_characteristics = create_sample_dataset_characteristics();

        let mut hyperparameters = HashMap::new();
        hyperparameters.insert("learning_rate".to_string(), ParameterValue::Float(0.01));
        hyperparameters.insert("n_estimators".to_string(), ParameterValue::Integer(100));

        let record = OptimizationRecord {
            dataset_id: "test_dataset".to_string(),
            algorithm_name: "RandomForest".to_string(),
            dataset_characteristics: dataset_characteristics.clone(),
            hyperparameters,
            performance_score: 0.85,
            optimization_time: 120.0,
            convergence_iterations: 50,
            validation_method: "5-fold-cv".to_string(),
            timestamp: 1234567890,
        };

        let result =
            meta_learning_recommend(&dataset_characteristics, "RandomForest", vec![record], None);

        // Insufficient history: expect an error, not a recommendation.
        assert!(result.is_err());
    }
}
1110}