1use crate::ParameterValue;
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10#[cfg(feature = "serde")]
11use sklears_core::error::{Result, SklearsError};
12use std::collections::HashMap;
13
/// Settings controlling how past evaluations are stored and weighted when
/// warm-starting a hyperparameter search.
#[derive(Debug, Clone)]
pub struct WarmStartConfig {
    /// Maximum number of records an `OptimizationHistory` keeps; once exceeded,
    /// the lowest-ranked records are dropped.
    pub max_history_size: usize,
    /// Flag named for top-k initialisation. NOTE(review): not read by any code
    /// in this module — confirm where it is consumed.
    pub use_top_k_init: bool,
    /// Count named for top-k initialisation (not read by any code in this module).
    pub top_k: usize,
    /// Flag named for surrogate-model warm-starting (not read in this module).
    pub use_surrogate_warmstart: bool,
    /// Per-hour multiplicative decay of a record's weight: `weight_decay^age_hours`.
    pub weight_decay: f64,
    /// Lower bound applied to the decayed weight.
    pub min_weight: f64,
    /// Flag named for adapting parameter ranges (not read in this module).
    pub adapt_parameter_ranges: bool,
    /// Flag named for cross-problem transfer learning (not read in this module).
    pub use_transfer_learning: bool,
}

impl Default for WarmStartConfig {
    fn default() -> Self {
        Self {
            // Sizing and weighting.
            max_history_size: 1_000,
            top_k: 10,
            weight_decay: 0.95,
            min_weight: 0.1,
            // Feature toggles: everything on except transfer learning.
            use_top_k_init: true,
            use_surrogate_warmstart: true,
            adapt_parameter_ranges: true,
            use_transfer_learning: false,
        }
    }
}
49
50#[derive(Debug, Clone)]
52#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
53pub struct EvaluationRecord {
54 pub parameters: ParameterValue,
56 pub score: f64,
58 pub timestamp: u64,
60 pub cv_scores: Option<Vec<f64>>,
62 pub score_std: Option<f64>,
64 pub duration_ms: Option<u64>,
66 pub metadata: HashMap<String, String>,
68}
69
70impl EvaluationRecord {
71 pub fn new(parameters: ParameterValue, score: f64) -> Self {
72 Self {
73 parameters,
74 score,
75 timestamp: std::time::SystemTime::now()
76 .duration_since(std::time::UNIX_EPOCH)
77 .unwrap_or_default()
78 .as_secs(),
79 cv_scores: None,
80 score_std: None,
81 duration_ms: None,
82 metadata: HashMap::new(),
83 }
84 }
85
86 pub fn with_cv_scores(mut self, cv_scores: Vec<f64>) -> Self {
87 self.score_std = Some(statistical_std(&cv_scores));
88 self.cv_scores = Some(cv_scores);
89 self
90 }
91
92 pub fn with_duration(mut self, duration_ms: u64) -> Self {
93 self.duration_ms = Some(duration_ms);
94 self
95 }
96
97 pub fn with_metadata(mut self, key: String, value: String) -> Self {
98 self.metadata.insert(key, value);
99 self
100 }
101
102 pub fn age_weight(&self, decay_factor: f64, min_weight: f64) -> f64 {
104 let current_time = std::time::SystemTime::now()
105 .duration_since(std::time::UNIX_EPOCH)
106 .unwrap_or_default()
107 .as_secs();
108
109 let age_hours = (current_time - self.timestamp) / 3600;
110 let weight = decay_factor.powf(age_hours as f64);
111 weight.max(min_weight)
112 }
113}
114
115#[derive(Debug, Clone)]
117pub struct OptimizationHistory {
118 records: Vec<EvaluationRecord>,
120 config: WarmStartConfig,
122 problem_signature: Option<String>,
124}
125
126impl OptimizationHistory {
127 pub fn new(config: WarmStartConfig) -> Self {
128 Self {
129 records: Vec::new(),
130 config,
131 problem_signature: None,
132 }
133 }
134
135 pub fn add_record(&mut self, record: EvaluationRecord) {
137 self.records.push(record);
138
139 if self.records.len() > self.config.max_history_size {
141 self.records.sort_by(|a, b| {
143 let score_diff = b
144 .score
145 .partial_cmp(&a.score)
146 .unwrap_or(std::cmp::Ordering::Equal);
147 if score_diff == std::cmp::Ordering::Equal {
148 b.timestamp.cmp(&a.timestamp) } else {
150 score_diff
151 }
152 });
153
154 self.records.truncate(self.config.max_history_size);
155 }
156 }
157
158 pub fn get_top_k(&self, k: usize) -> Vec<&EvaluationRecord> {
160 let mut sorted_records = self.records.iter().collect::<Vec<_>>();
161 sorted_records.sort_by(|a, b| {
162 b.score
163 .partial_cmp(&a.score)
164 .unwrap_or(std::cmp::Ordering::Equal)
165 });
166 sorted_records.into_iter().take(k).collect()
167 }
168
169 pub fn get_weighted_data(&self) -> Vec<(ParameterValue, f64, f64)> {
171 self.records
172 .iter()
173 .map(|record| {
174 let weight = record.age_weight(self.config.weight_decay, self.config.min_weight);
175 (record.parameters.clone(), record.score, weight)
176 })
177 .collect()
178 }
179
180 pub fn get_similar_configs(
182 &self,
183 target_params: &ParameterValue,
184 similarity_threshold: f64,
185 max_count: usize,
186 ) -> Vec<&EvaluationRecord> {
187 let mut similar = self
188 .records
189 .iter()
190 .filter_map(|record| {
191 let distance = parameter_distance(target_params, &record.parameters);
192 if distance <= similarity_threshold {
193 Some((record, distance))
194 } else {
195 None
196 }
197 })
198 .collect::<Vec<_>>();
199
200 similar.sort_by(|a, b| {
202 let dist_cmp = a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal);
203 if dist_cmp == std::cmp::Ordering::Equal {
204 b.0.score
205 .partial_cmp(&a.0.score)
206 .unwrap_or(std::cmp::Ordering::Equal)
207 } else {
208 dist_cmp
209 }
210 });
211
212 similar
213 .into_iter()
214 .take(max_count)
215 .map(|(record, _)| record)
216 .collect()
217 }
218
219 pub fn analyze_parameter_ranges(&self) -> HashMap<String, (f64, f64, f64)> {
221 let mut param_analysis = HashMap::new();
222
223 if self.records.is_empty() {
224 return param_analysis;
225 }
226
227 let mut float_params: HashMap<String, Vec<f64>> = HashMap::new();
229
230 for record in &self.records {
231 if let ParameterValue::Float(val) = &record.parameters {
232 float_params
233 .entry("default".to_string())
234 .or_default()
235 .push(*val);
236 }
237 }
239
240 for (param_name, values) in float_params {
241 if !values.is_empty() {
242 let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
243 let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
244 let mean_val = values.iter().sum::<f64>() / values.len() as f64;
245
246 param_analysis.insert(param_name, (min_val, max_val, mean_val));
247 }
248 }
249
250 param_analysis
251 }
252
253 pub fn get_statistics(&self) -> OptimizationStatistics {
255 if self.records.is_empty() {
256 return OptimizationStatistics::default();
257 }
258
259 let scores: Vec<f64> = self.records.iter().map(|r| r.score).collect();
260 let durations: Vec<u64> = self.records.iter().filter_map(|r| r.duration_ms).collect();
261
262 let best_score = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
263 let worst_score = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
264 let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
265 let score_std = statistical_std(&scores);
266
267 let mean_duration = if !durations.is_empty() {
268 durations.iter().sum::<u64>() as f64 / durations.len() as f64
269 } else {
270 0.0
271 };
272
273 OptimizationStatistics {
274 total_evaluations: self.records.len(),
275 best_score,
276 worst_score,
277 mean_score,
278 score_std,
279 mean_duration_ms: mean_duration,
280 unique_configurations: self.count_unique_configurations(),
281 }
282 }
283
284 fn count_unique_configurations(&self) -> usize {
285 let mut unique_params = std::collections::HashSet::new();
286 for record in &self.records {
287 unique_params.insert(format!("{:?}", record.parameters));
288 }
289 unique_params.len()
290 }
291
292 pub fn set_problem_signature(&mut self, signature: String) {
294 self.problem_signature = Some(signature);
295 }
296
297 pub fn problem_signature(&self) -> Option<&str> {
299 self.problem_signature.as_deref()
300 }
301
302 #[cfg(feature = "serde")]
304 pub fn export_to_json(&self) -> Result<String> {
305 serde_json::to_string_pretty(&self.records)
306 .map_err(|e| SklearsError::InvalidInput(format!("Failed to export history: {}", e)))
307 }
308
309 #[cfg(feature = "serde")]
311 pub fn import_from_json(&mut self, json_data: &str) -> Result<()> {
312 let imported_records: Vec<EvaluationRecord> = serde_json::from_str(json_data)
313 .map_err(|e| SklearsError::InvalidInput(format!("Failed to import history: {}", e)))?;
314
315 for record in imported_records {
316 self.add_record(record);
317 }
318
319 Ok(())
320 }
321}
322
323#[derive(Debug, Clone)]
325pub struct OptimizationStatistics {
326 pub total_evaluations: usize,
327 pub best_score: f64,
328 pub worst_score: f64,
329 pub mean_score: f64,
330 pub score_std: f64,
331 pub mean_duration_ms: f64,
332 pub unique_configurations: usize,
333}
334
335impl Default for OptimizationStatistics {
336 fn default() -> Self {
337 Self {
338 total_evaluations: 0,
339 best_score: f64::NEG_INFINITY,
340 worst_score: f64::INFINITY,
341 mean_score: 0.0,
342 score_std: 0.0,
343 mean_duration_ms: 0.0,
344 unique_configurations: 0,
345 }
346 }
347}
348
349#[derive(Debug, Clone)]
351pub enum WarmStartStrategy {
352 TopK(usize),
354 WeightedSampling(usize),
356 SurrogateModel(usize),
358 ClusterBased(usize),
360 Combined(Vec<WarmStartStrategy>),
362}
363
364pub struct WarmStartInitializer {
366 history: OptimizationHistory,
367 strategy: WarmStartStrategy,
368 config: WarmStartConfig,
369}
370
371impl WarmStartInitializer {
372 pub fn new(
373 history: OptimizationHistory,
374 strategy: WarmStartStrategy,
375 config: WarmStartConfig,
376 ) -> Self {
377 Self {
378 history,
379 strategy,
380 config,
381 }
382 }
383
384 pub fn generate_initial_points(&self, n_points: usize) -> Vec<ParameterValue> {
386 match &self.strategy {
387 WarmStartStrategy::TopK(k) => self.generate_top_k_points(*k.min(&n_points)),
388 WarmStartStrategy::WeightedSampling(n) => {
389 self.generate_weighted_sample_points(*n.min(&n_points))
390 }
391 WarmStartStrategy::SurrogateModel(n) => {
392 self.generate_surrogate_points(*n.min(&n_points))
393 }
394 WarmStartStrategy::ClusterBased(n) => {
395 self.generate_cluster_based_points(*n.min(&n_points))
396 }
397 WarmStartStrategy::Combined(strategies) => {
398 self.generate_combined_points(strategies, n_points)
399 }
400 }
401 }
402
403 fn generate_top_k_points(&self, k: usize) -> Vec<ParameterValue> {
404 self.history
405 .get_top_k(k)
406 .into_iter()
407 .map(|record| record.parameters.clone())
408 .collect()
409 }
410
411 fn generate_weighted_sample_points(&self, n: usize) -> Vec<ParameterValue> {
412 let weighted_data = self.history.get_weighted_data();
413 if weighted_data.is_empty() {
414 return Vec::new();
415 }
416
417 let total_weight: f64 = weighted_data.iter().map(|(_, _, w)| w).sum();
419 let mut cumulative_weights = Vec::new();
420 let mut running_sum = 0.0;
421
422 for (_, _, weight) in &weighted_data {
423 running_sum += weight / total_weight;
424 cumulative_weights.push(running_sum);
425 }
426
427 use scirs2_core::random::prelude::*;
429 let mut rng = thread_rng();
430 let mut selected = Vec::new();
431
432 for _ in 0..n {
433 let random_val: f64 = rng.gen();
434 for (i, &cum_weight) in cumulative_weights.iter().enumerate() {
435 if random_val <= cum_weight {
436 selected.push(weighted_data[i].0.clone());
437 break;
438 }
439 }
440 }
441
442 selected
443 }
444
445 fn generate_surrogate_points(&self, n: usize) -> Vec<ParameterValue> {
446 let top_configs = self.history.get_top_k(n.min(10));
449
450 if top_configs.is_empty() {
451 return Vec::new();
452 }
453
454 top_configs
457 .into_iter()
458 .take(n)
459 .map(|record| record.parameters.clone())
460 .collect()
461 }
462
463 fn generate_cluster_based_points(&self, n: usize) -> Vec<ParameterValue> {
464 let good_configs = self.history.get_top_k(n * 2); if good_configs.is_empty() {
469 return Vec::new();
470 }
471
472 let mut clusters: Vec<Vec<EvaluationRecord>> = Vec::new();
474 let similarity_threshold = 0.5;
475
476 for config in good_configs {
477 let mut assigned = false;
478 for cluster in &mut clusters {
479 let cluster_center: &EvaluationRecord = &cluster[0];
480 let distance = parameter_distance(&config.parameters, &cluster_center.parameters);
481
482 if distance <= similarity_threshold {
483 cluster.push(config.clone());
484 assigned = true;
485 break;
486 }
487 }
488
489 if !assigned {
490 clusters.push(vec![config.clone()]);
491 }
492 }
493
494 clusters
496 .into_iter()
497 .take(n)
498 .map(|cluster| cluster[0].parameters.clone())
499 .collect()
500 }
501
502 fn generate_combined_points(
503 &self,
504 strategies: &[WarmStartStrategy],
505 n_points: usize,
506 ) -> Vec<ParameterValue> {
507 let points_per_strategy = n_points / strategies.len().max(1);
508 let mut all_points = Vec::new();
509
510 for strategy in strategies {
511 let strategy_initializer = WarmStartInitializer::new(
512 self.history.clone(),
513 strategy.clone(),
514 self.config.clone(),
515 );
516 let points = strategy_initializer.generate_initial_points(points_per_strategy);
517 all_points.extend(points);
518 }
519
520 if all_points.len() < n_points {
522 let remaining = n_points - all_points.len();
523 let top_points = self.generate_top_k_points(remaining);
524 all_points.extend(top_points);
525 }
526
527 all_points.into_iter().take(n_points).collect()
528 }
529
530 pub fn update_history(&mut self, record: EvaluationRecord) {
532 self.history.add_record(record);
533 }
534
535 pub fn statistics(&self) -> OptimizationStatistics {
537 self.history.get_statistics()
538 }
539
540 pub fn history(&self) -> &OptimizationHistory {
542 &self.history
543 }
544
545 pub fn history_mut(&mut self) -> &mut OptimizationHistory {
547 &mut self.history
548 }
549}
550
551pub struct TransferLearning {
553 problem_histories: HashMap<String, OptimizationHistory>,
555 similarity_threshold: f64,
557}
558
559impl TransferLearning {
560 pub fn new(similarity_threshold: f64) -> Self {
561 Self {
562 problem_histories: HashMap::new(),
563 similarity_threshold,
564 }
565 }
566
567 pub fn add_problem_history(&mut self, problem_id: String, history: OptimizationHistory) {
569 self.problem_histories.insert(problem_id, history);
570 }
571
572 pub fn get_transfer_recommendations(
574 &self,
575 problem_signature: &str,
576 n_recommendations: usize,
577 ) -> Vec<ParameterValue> {
578 let mut all_recommendations = Vec::new();
579
580 for history in self.problem_histories.values() {
581 if let Some(hist_signature) = history.problem_signature() {
582 let similarity =
583 self.calculate_problem_similarity(problem_signature, hist_signature);
584
585 if similarity >= self.similarity_threshold {
586 let top_configs = history.get_top_k(n_recommendations);
587 for config in top_configs {
588 all_recommendations
589 .push((config.parameters.clone(), config.score * similarity));
590 }
591 }
592 }
593 }
594
595 all_recommendations
597 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
598
599 all_recommendations
600 .into_iter()
601 .take(n_recommendations)
602 .map(|(params, _)| params)
603 .collect()
604 }
605
606 fn calculate_problem_similarity(&self, sig1: &str, sig2: &str) -> f64 {
607 let words1: std::collections::HashSet<&str> = sig1.split_whitespace().collect();
610 let words2: std::collections::HashSet<&str> = sig2.split_whitespace().collect();
611
612 let intersection = words1.intersection(&words2).count();
613 let union = words1.union(&words2).count();
614
615 if union == 0 {
616 0.0
617 } else {
618 intersection as f64 / union as f64
619 }
620 }
621}
622
623fn parameter_distance(p1: &ParameterValue, p2: &ParameterValue) -> f64 {
626 match (p1, p2) {
627 (ParameterValue::Float(v1), ParameterValue::Float(v2)) => (v1 - v2).abs(),
628 (ParameterValue::Integer(v1), ParameterValue::Integer(v2)) => (*v1 - *v2).abs() as f64,
629 (ParameterValue::String(s1), ParameterValue::String(s2)) => {
630 if s1 == s2 {
631 0.0
632 } else {
633 1.0
634 }
635 }
636 _ => 1.0, }
638}
639
/// Sample (Bessel-corrected, `n - 1` denominator) standard deviation of
/// `values`; returns 0.0 for fewer than two values.
fn statistical_std(values: &[f64]) -> f64 {
    let n = values.len();
    if n < 2 {
        return 0.0;
    }

    let mean = values.iter().sum::<f64>() / n as f64;
    let sum_sq_dev = values.iter().fold(0.0, |acc, &v| {
        let dev = v - mean;
        acc + dev * dev
    });

    (sum_sq_dev / (n - 1) as f64).sqrt()
}
651
652#[allow(non_snake_case)]
653#[cfg(test)]
654mod tests {
655 use super::*;
656
657 #[test]
658 fn test_evaluation_record() {
659 let params = ParameterValue::Float(1.5);
660 let record = EvaluationRecord::new(params.clone(), 0.85)
661 .with_cv_scores(vec![0.8, 0.85, 0.9])
662 .with_duration(1500)
663 .with_metadata("algorithm".to_string(), "random_forest".to_string());
664
665 assert_eq!(record.score, 0.85);
666 assert_eq!(record.parameters, params);
667 assert!(record.cv_scores.is_some());
668 assert!(record.score_std.is_some());
669 assert_eq!(record.duration_ms, Some(1500));
670 assert_eq!(
671 record.metadata.get("algorithm"),
672 Some(&"random_forest".to_string())
673 );
674 }
675
676 #[test]
677 fn test_optimization_history() {
678 let config = WarmStartConfig::default();
679 let mut history = OptimizationHistory::new(config);
680
681 let record1 = EvaluationRecord::new(ParameterValue::Float(1.0), 0.8);
683 let record2 = EvaluationRecord::new(ParameterValue::Float(2.0), 0.9);
684 let record3 = EvaluationRecord::new(ParameterValue::Float(3.0), 0.7);
685
686 history.add_record(record1);
687 history.add_record(record2);
688 history.add_record(record3);
689
690 let top_2 = history.get_top_k(2);
692 assert_eq!(top_2.len(), 2);
693 assert_eq!(top_2[0].score, 0.9); assert_eq!(top_2[1].score, 0.8); let stats = history.get_statistics();
698 assert_eq!(stats.total_evaluations, 3);
699 assert_eq!(stats.best_score, 0.9);
700 assert_eq!(stats.worst_score, 0.7);
701 assert!((stats.mean_score - 0.8).abs() < 1e-6);
702 }
703
704 #[test]
705 fn test_warm_start_initializer() {
706 let config = WarmStartConfig::default();
707 let mut history = OptimizationHistory::new(config.clone());
708
709 for i in 0..10 {
711 let score = 0.5 + (i as f64) * 0.05; let record = EvaluationRecord::new(ParameterValue::Float(i as f64), score);
713 history.add_record(record);
714 }
715
716 let initializer = WarmStartInitializer::new(history, WarmStartStrategy::TopK(5), config);
717
718 let initial_points = initializer.generate_initial_points(3);
719 assert_eq!(initial_points.len(), 3);
720
721 if let ParameterValue::Float(val) = &initial_points[0] {
723 assert!(*val >= 7.0); }
725 }
726
727 #[test]
728 fn test_parameter_distance() {
729 let p1 = ParameterValue::Float(1.0);
730 let p2 = ParameterValue::Float(2.0);
731 let p3 = ParameterValue::String("test".to_string());
732 let p4 = ParameterValue::String("test".to_string());
733
734 assert_eq!(parameter_distance(&p1, &p2), 1.0);
735 assert_eq!(parameter_distance(&p3, &p4), 0.0);
736 assert_eq!(parameter_distance(&p1, &p3), 1.0); }
738
739 #[test]
740 fn test_transfer_learning() {
741 let mut transfer = TransferLearning::new(0.5);
742
743 let config = WarmStartConfig::default();
744 let mut history1 = OptimizationHistory::new(config.clone());
745 history1.set_problem_signature("classification tree depth".to_string());
746
747 let record = EvaluationRecord::new(ParameterValue::Float(10.0), 0.95);
748 history1.add_record(record);
749
750 transfer.add_problem_history("problem1".to_string(), history1);
751
752 let recommendations =
753 transfer.get_transfer_recommendations("classification tree depth optimization", 2);
754
755 assert!(!recommendations.is_empty());
756 }
757
758 #[test]
759 #[cfg(feature = "serde")]
760 fn test_json_serialization() {
761 let config = WarmStartConfig::default();
762 let mut history = OptimizationHistory::new(config.clone());
763
764 let record = EvaluationRecord::new(ParameterValue::Float(1.5), 0.85)
765 .with_cv_scores(vec![0.8, 0.85, 0.9]);
766
767 history.add_record(record);
768
769 let json = history.export_to_json().unwrap();
770 assert!(json.contains("1.5"));
771 assert!(json.contains("0.85"));
772
773 let mut new_history = OptimizationHistory::new(config);
774 new_history.import_from_json(&json).unwrap();
775
776 assert_eq!(new_history.records.len(), 1);
777 assert_eq!(new_history.records[0].score, 0.85);
778 }
779}