1use scirs2_core::ndarray::Array2;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone)]
12pub struct TuningConfig {
13 pub strategy: SearchStrategy,
15 pub metric: EvaluationMetric,
17 pub cv_config: CrossValidationConfig,
19 pub max_evaluations: usize,
21 pub early_stopping: Option<EarlyStoppingConfig>,
23 pub random_seed: Option<u64>,
25 pub parallel_config: Option<ParallelConfig>,
27 pub resource_constraints: ResourceConstraints,
29}
30
31impl Default for TuningConfig {
32 fn default() -> Self {
33 Self {
34 strategy: SearchStrategy::RandomSearch { n_trials: 50 },
35 metric: EvaluationMetric::SilhouetteScore,
36 cv_config: CrossValidationConfig::default(),
37 max_evaluations: 100,
38 early_stopping: None,
39 random_seed: Some(42),
40 parallel_config: None,
41 resource_constraints: ResourceConstraints::default(),
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
48pub enum SearchStrategy {
49 GridSearch,
51 RandomSearch { n_trials: usize },
53 BayesianOptimization {
55 n_initial_points: usize,
56 acquisition_function: AcquisitionFunction,
57 },
58 AdaptiveSearch {
60 initial_strategy: Box<SearchStrategy>,
61 adaptation_frequency: usize,
62 },
63 MultiObjective {
65 objectives: Vec<EvaluationMetric>,
66 strategy: Box<SearchStrategy>,
67 },
68 EnsembleSearch {
70 strategies: Vec<SearchStrategy>,
71 weights: Vec<f64>,
72 },
73 EvolutionarySearch {
75 population_size: usize,
76 n_generations: usize,
77 mutation_rate: f64,
78 crossover_rate: f64,
79 },
80 SMBO {
82 surrogate_model: SurrogateModel,
83 acquisition_function: AcquisitionFunction,
84 },
85}
86
/// Acquisition function used by model-based strategies to rank candidates.
#[derive(Debug, Clone)]
pub enum AcquisitionFunction {
    /// Expected improvement over the current best score.
    ExpectedImprovement,
    /// Upper confidence bound; `beta` weights exploration against exploitation.
    UpperConfidenceBound { beta: f64 },
    /// Probability of improving on the current best score.
    ProbabilityOfImprovement,
    /// Entropy search.
    EntropySearch,
    /// Knowledge gradient.
    KnowledgeGradient,
    /// Thompson sampling from the surrogate's posterior.
    ThompsonSampling,
}
103
104#[derive(Debug, Clone)]
106pub enum SurrogateModel {
107 GaussianProcess { kernel: KernelType, noise: f64 },
109 RandomForest {
111 n_trees: usize,
112 max_depth: Option<usize>,
113 },
114 GradientBoosting {
116 n_estimators: usize,
117 learning_rate: f64,
118 },
119}
120
/// Covariance kernel for Gaussian-process surrogates.
#[derive(Debug, Clone)]
pub enum KernelType {
    /// Radial basis function (squared-exponential) kernel.
    RBF { length_scale: f64 },
    /// Matérn kernel; `nu` controls smoothness.
    Matern { length_scale: f64, nu: f64 },
    /// Linear (dot-product) kernel.
    Linear,
    /// Polynomial kernel of the given `degree`.
    Polynomial { degree: usize },
}
133
/// Metric used to score a clustering produced by a candidate parameter set.
///
/// NOTE(review): the metrics differ in direction (for example, a lower
/// Davies-Bouldin index is better but a higher silhouette score is better);
/// confirm that the tuner orients each one correctly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EvaluationMetric {
    /// Silhouette score.
    SilhouetteScore,
    /// Davies-Bouldin index.
    DaviesBouldinIndex,
    /// Calinski-Harabasz index.
    CalinskiHarabaszIndex,
    /// Within-cluster inertia.
    Inertia,
    /// Adjusted Rand index (presumably requires reference labels — confirm).
    AdjustedRandIndex,
    /// User-defined metric identified by name.
    Custom(String),
    /// Consensus score across an ensemble of clusterings.
    EnsembleConsensus,
    /// Stability of the clustering across resamples.
    Stability,
    /// Mutual-information-based score.
    MutualInformation,
}
156
157#[derive(Debug, Clone)]
159pub struct CrossValidationConfig {
160 pub n_folds: usize,
162 pub validation_ratio: f64,
164 pub strategy: CVStrategy,
166 pub shuffle: bool,
168}
169
170impl Default for CrossValidationConfig {
171 fn default() -> Self {
172 Self {
173 n_folds: 5,
174 validation_ratio: 0.2,
175 strategy: CVStrategy::KFold,
176 shuffle: true,
177 }
178 }
179}
180
/// How data is split into training/validation parts during cross-validation.
#[derive(Debug, Clone)]
pub enum CVStrategy {
    /// Plain K-fold.
    KFold,
    /// K-fold that preserves class/cluster proportions per fold.
    StratifiedKFold,
    /// Order-preserving splits for time-series data.
    TimeSeriesSplit,
    /// `n_bootstrap` bootstrap resamples.
    Bootstrap { n_bootstrap: usize },
    /// Combination of several strategies (nested via `Vec`, so no `Box` is needed).
    EnsembleCV { strategies: Vec<CVStrategy> },
    /// `n_splits` random splits; `test_size` is presumably a fraction of the data — confirm.
    MonteCarlo { n_splits: usize, test_size: f64 },
    /// Nested cross-validation with `outer_folds` outer and `inner_folds` inner folds.
    NestedCV {
        outer_folds: usize,
        inner_folds: usize,
    },
}
202
/// Early-stopping rules for a tuning run.
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Number of checks without sufficient improvement tolerated before stopping.
    pub patience: usize,
    /// Smallest score gain that counts as an improvement.
    pub min_improvement: f64,
    /// How often (in evaluations) the stopping condition is checked — presumably; confirm.
    pub evaluation_frequency: usize,
}
213
214#[derive(Debug, Clone)]
216pub struct ParallelConfig {
217 pub n_workers: usize,
219 pub batch_size: usize,
221 pub load_balancing: LoadBalancingStrategy,
223}
224
/// How evaluation work is distributed among parallel workers.
///
/// Derives `PartialEq`/`Eq` so the variants can be compared directly,
/// consistent with the fieldless [`EvaluationMetric`] enum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Assign work to workers in a fixed cyclic order.
    RoundRobin,
    /// Idle workers take queued work from busy ones.
    WorkStealing,
    /// Assignment decided at runtime (policy defined by the tuner).
    Dynamic,
}
235
/// Optional resource limits for a tuning run.
///
/// Every field is `None` by default (derived `Default`), meaning no limit is set.
/// NOTE(review): units are not fixed by these types; memory is presumably
/// bytes and times presumably seconds — confirm against the tuner.
#[derive(Debug, Clone, Default)]
pub struct ResourceConstraints {
    /// Memory limit for a single evaluation.
    pub max_memory_per_evaluation: Option<usize>,
    /// Time limit for a single evaluation.
    pub max_time_per_evaluation: Option<f64>,
    /// Time limit for the whole run.
    pub max_total_time: Option<f64>,
}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
249pub enum HyperParameter {
250 Integer { min: i64, max: i64 },
252 Float { min: f64, max: f64 },
254 Categorical { choices: Vec<String> },
256 Boolean,
258 LogUniform { min: f64, max: f64 },
260 IntegerChoices { choices: Vec<i64> },
262}
263
264#[derive(Debug, Clone)]
266pub struct SearchSpace {
267 pub parameters: HashMap<String, HyperParameter>,
269 pub constraints: Vec<ParameterConstraint>,
271}
272
273#[derive(Debug, Clone)]
275pub enum ParameterConstraint {
276 Conditional {
278 condition: String,
279 constraint: Box<ParameterConstraint>,
280 },
281 Range {
283 parameter: String,
284 min: f64,
285 max: f64,
286 },
287 Dependency {
289 dependent: String,
290 dependency: String,
291 relationship: DependencyRelationship,
292 },
293}
294
/// Relationship between a dependent parameter and the parameter it depends on.
#[derive(Debug, Clone)]
pub enum DependencyRelationship {
    /// Presumably `dependent = k * dependency + c` — confirm with the evaluator.
    Linear { k: f64, c: f64 },
    /// Presumably `dependent = ratio * dependency` — confirm with the evaluator.
    Proportional { ratio: f64 },
    /// User-defined relationship identified by a string.
    Custom(String),
}
305
/// Outcome of evaluating one candidate parameter set.
#[derive(Debug, Clone)]
pub struct EvaluationResult {
    /// Parameter values that were evaluated, keyed by parameter name.
    pub parameters: HashMap<String, f64>,
    /// Score under the configured metric.
    pub score: f64,
    /// Other metrics computed for this candidate, keyed by name.
    pub additional_metrics: HashMap<String, f64>,
    /// Wall-clock time spent on the evaluation (units presumably seconds — confirm).
    pub evaluation_time: f64,
    /// Memory used by the evaluation, if measured.
    pub memory_usage: Option<usize>,
    /// Per-fold cross-validation scores.
    pub cv_scores: Vec<f64>,
    /// Standard deviation of the cross-validation scores.
    pub cv_std: f64,
    /// Free-form string annotations.
    pub metadata: HashMap<String, String>,
}
326
327#[derive(Debug, Clone)]
329pub struct TuningResult {
330 pub best_parameters: HashMap<String, f64>,
332 pub best_score: f64,
334 pub evaluation_history: Vec<EvaluationResult>,
336 pub convergence_info: ConvergenceInfo,
338 pub exploration_stats: ExplorationStats,
340 pub total_time: f64,
342 pub ensemble_results: Option<EnsembleResults>,
344 pub pareto_front: Option<Vec<HashMap<String, f64>>>,
346}
347
348#[derive(Debug, Clone)]
350pub struct EnsembleResults {
351 pub member_results: Vec<TuningResult>,
353 pub consensus_parameters: HashMap<String, f64>,
355 pub agreement_score: f64,
357 pub diversity_metrics: HashMap<String, f64>,
359}
360
361#[derive(Debug, Clone)]
363pub struct BayesianState {
364 pub observations: Vec<(HashMap<String, f64>, f64)>,
366 pub gp_mean: Option<f64>,
368 pub gp_covariance: Option<Array2<f64>>,
370 pub acquisition_values: Vec<f64>,
372 pub parameter_names: Vec<String>,
374 pub gp_hyperparameters: GpHyperparameters,
376 pub noise_level: f64,
378 pub currentbest: f64,
380}
381
382#[derive(Debug, Clone)]
384pub struct GpHyperparameters {
385 pub length_scales: Vec<f64>,
387 pub signal_variance: f64,
389 pub noise_variance: f64,
391 pub kernel_type: KernelType,
393}
394
395#[derive(Debug, Clone)]
397pub struct ConvergenceInfo {
398 pub converged: bool,
400 pub convergence_iteration: Option<usize>,
402 pub stopping_reason: StoppingReason,
404}
405
/// Reason a tuning run stopped.
///
/// Derives `PartialEq`/`Eq` so reasons can be compared directly, consistent
/// with the fieldless [`EvaluationMetric`] enum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoppingReason {
    /// The evaluation budget was exhausted.
    MaxEvaluations,
    /// Early-stopping criteria were met.
    EarlyStopping,
    /// A time limit was reached.
    TimeLimit,
    /// The search converged.
    Convergence,
    /// The user interrupted the run.
    UserInterruption,
    /// A resource limit other than time was hit.
    ResourceConstraints,
}
422
/// Statistics describing how a tuning run explored the search space.
#[derive(Debug, Clone)]
pub struct ExplorationStats {
    /// Share of the search space covered (presumably in `[0, 1]` — confirm).
    pub coverage: f64,
    /// Values tried for each parameter, keyed by parameter name.
    pub parameter_distributions: HashMap<String, Vec<f64>>,
    /// Estimated importance of each parameter, keyed by parameter name.
    pub parameter_importance: HashMap<String, f64>,
}