scirs2_cluster/tuning/config.rs

//! Hyperparameter tuning configuration types and structures
//!
//! This module contains all configuration structures and enums used for
//! hyperparameter optimization in clustering algorithms.

use scirs2_core::ndarray::Array2;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Hyperparameter tuning configuration
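///
/// # Examples
///
/// A minimal sketch of overriding the defaults, assuming this module's types
/// are in scope (the chosen values are illustrative, not recommendations):
///
/// ```ignore
/// let config = TuningConfig {
///     strategy: SearchStrategy::RandomSearch { n_trials: 200 },
///     metric: EvaluationMetric::CalinskiHarabaszIndex,
///     max_evaluations: 200,
///     ..TuningConfig::default()
/// };
/// assert_eq!(config.max_evaluations, 200);
/// ```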
#[derive(Debug, Clone)]
pub struct TuningConfig {
    /// Search strategy to use
    pub strategy: SearchStrategy,
    /// Evaluation metric for optimization
    pub metric: EvaluationMetric,
    /// Cross-validation configuration
    pub cv_config: CrossValidationConfig,
    /// Maximum number of evaluations
    pub max_evaluations: usize,
    /// Early stopping criteria
    pub early_stopping: Option<EarlyStoppingConfig>,
    /// Random seed for reproducible results
    pub random_seed: Option<u64>,
    /// Parallel evaluation configuration
    pub parallel_config: Option<ParallelConfig>,
    /// Resource constraints
    pub resource_constraints: ResourceConstraints,
}

impl Default for TuningConfig {
    fn default() -> Self {
        Self {
            strategy: SearchStrategy::RandomSearch { n_trials: 50 },
            metric: EvaluationMetric::SilhouetteScore,
            cv_config: CrossValidationConfig::default(),
            max_evaluations: 100,
            early_stopping: None,
            random_seed: Some(42),
            parallel_config: None,
            resource_constraints: ResourceConstraints::default(),
        }
    }
}

/// Hyperparameter search strategies
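///
/// # Examples
///
/// A sketch of composing strategies, assuming this module's types are in scope
/// (all numeric values are illustrative):
///
/// ```ignore
/// // Bayesian optimization seeded with 10 random initial evaluations.
/// let bayesian = SearchStrategy::BayesianOptimization {
///     n_initial_points: 10,
///     acquisition_function: AcquisitionFunction::ExpectedImprovement,
/// };
///
/// // Adaptive search that revisits its underlying strategy every 20 evaluations.
/// let adaptive = SearchStrategy::AdaptiveSearch {
///     initial_strategy: Box::new(bayesian),
///     adaptation_frequency: 20,
/// };
/// ```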
#[derive(Debug, Clone)]
pub enum SearchStrategy {
    /// Exhaustive grid search
    GridSearch,
    /// Random search with specified number of trials
    RandomSearch { n_trials: usize },
    /// Bayesian optimization using Gaussian processes
    BayesianOptimization {
        n_initial_points: usize,
        acquisition_function: AcquisitionFunction,
    },
    /// Adaptive search that adjusts based on results
    AdaptiveSearch {
        initial_strategy: Box<SearchStrategy>,
        adaptation_frequency: usize,
    },
    /// Multi-objective optimization
    MultiObjective {
        objectives: Vec<EvaluationMetric>,
        strategy: Box<SearchStrategy>,
    },
    /// Ensemble search combining multiple strategies
    EnsembleSearch {
        strategies: Vec<SearchStrategy>,
        weights: Vec<f64>,
    },
    /// Evolutionary search strategy
    EvolutionarySearch {
        population_size: usize,
        n_generations: usize,
        mutation_rate: f64,
        crossover_rate: f64,
    },
    /// Sequential model-based optimization
    SMBO {
        surrogate_model: SurrogateModel,
        acquisition_function: AcquisitionFunction,
    },
}

/// Acquisition functions for Bayesian optimization
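///
/// # Examples
///
/// A sketch of the scoring rule commonly used for `UpperConfidenceBound`; the
/// evaluator in this crate may use a slightly different convention:
///
/// ```ignore
/// // Larger `beta` weights the predictive uncertainty more heavily,
/// // favoring exploration over exploitation.
/// fn ucb_score(mean: f64, std_dev: f64, beta: f64) -> f64 {
///     mean + beta.sqrt() * std_dev
/// }
/// ```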
#[derive(Debug, Clone)]
pub enum AcquisitionFunction {
    /// Expected Improvement
    ExpectedImprovement,
    /// Upper Confidence Bound
    UpperConfidenceBound { beta: f64 },
    /// Probability of Improvement
    ProbabilityOfImprovement,
    /// Entropy Search
    EntropySearch,
    /// Knowledge Gradient
    KnowledgeGradient,
    /// Thompson Sampling
    ThompsonSampling,
}

/// Surrogate models for SMBO
#[derive(Debug, Clone)]
pub enum SurrogateModel {
    /// Gaussian Process
    GaussianProcess { kernel: KernelType, noise: f64 },
    /// Random Forest
    RandomForest {
        n_trees: usize,
        max_depth: Option<usize>,
    },
    /// Gradient Boosting
    GradientBoosting {
        n_estimators: usize,
        learning_rate: f64,
    },
}

/// Kernel types for Gaussian Processes
#[derive(Debug, Clone)]
pub enum KernelType {
    /// Radial Basis Function (RBF)
    RBF { length_scale: f64 },
    /// Matérn kernel
    Matern { length_scale: f64, nu: f64 },
    /// Linear kernel
    Linear,
    /// Polynomial kernel
    Polynomial { degree: usize },
}

/// Evaluation metrics for hyperparameter optimization
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EvaluationMetric {
    /// Silhouette coefficient (higher is better)
    SilhouetteScore,
    /// Davies-Bouldin index (lower is better)
    DaviesBouldinIndex,
    /// Calinski-Harabasz index (higher is better)
    CalinskiHarabaszIndex,
    /// Within-cluster sum of squares (lower is better)
    Inertia,
    /// Adjusted Rand Index (for labeled data)
    AdjustedRandIndex,
    /// Custom metric
    Custom(String),
    /// Ensemble consensus score
    EnsembleConsensus,
    /// Stability-based metrics
    Stability,
    /// Information-theoretic metrics
    MutualInformation,
}

/// Cross-validation configuration
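///
/// # Examples
///
/// A sketch of a Monte Carlo setup, assuming this module's types are in scope
/// (split counts and sizes are illustrative):
///
/// ```ignore
/// let cv = CrossValidationConfig {
///     strategy: CVStrategy::MonteCarlo { n_splits: 10, test_size: 0.25 },
///     ..CrossValidationConfig::default()
/// };
/// ```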
#[derive(Debug, Clone)]
pub struct CrossValidationConfig {
    /// Number of folds
    pub n_folds: usize,
    /// Fraction of data to use for validation
    pub validation_ratio: f64,
    /// Strategy for cross-validation
    pub strategy: CVStrategy,
    /// Shuffle data before splitting
    pub shuffle: bool,
}

impl Default for CrossValidationConfig {
    fn default() -> Self {
        Self {
            n_folds: 5,
            validation_ratio: 0.2,
            strategy: CVStrategy::KFold,
            shuffle: true,
        }
    }
}

/// Cross-validation strategies
#[derive(Debug, Clone)]
pub enum CVStrategy {
    /// K-fold cross-validation
    KFold,
    /// Stratified K-fold (for labeled data)
    StratifiedKFold,
    /// Time series split (preserves temporal order)
    TimeSeriesSplit,
    /// Bootstrap cross-validation
    Bootstrap { n_bootstrap: usize },
    /// Ensemble cross-validation (multiple CV strategies)
    EnsembleCV { strategies: Vec<CVStrategy> },
    /// Monte Carlo cross-validation
    MonteCarlo { n_splits: usize, test_size: f64 },
    /// Nested cross-validation
    NestedCV {
        outer_folds: usize,
        inner_folds: usize,
    },
}

/// Early stopping configuration
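///
/// # Examples
///
/// A sketch, assuming this module's types are in scope: stop once 10
/// consecutive evaluations fail to improve the score by at least 1e-3
/// (values are illustrative):
///
/// ```ignore
/// let early_stopping = EarlyStoppingConfig {
///     patience: 10,
///     min_improvement: 1e-3,
///     evaluation_frequency: 1,
/// };
/// ```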
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Patience (number of evaluations without improvement)
    pub patience: usize,
    /// Minimum improvement required
    pub min_improvement: f64,
    /// Evaluation frequency
    pub evaluation_frequency: usize,
}

/// Parallel evaluation configuration
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Number of parallel workers
    pub n_workers: usize,
    /// Batch size for parallel evaluation
    pub batch_size: usize,
    /// Load balancing strategy
    pub load_balancing: LoadBalancingStrategy,
}

/// Load balancing strategies for parallel evaluation
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Round-robin assignment
    RoundRobin,
    /// Work stealing
    WorkStealing,
    /// Dynamic load balancing
    Dynamic,
}

/// Resource constraints for hyperparameter tuning
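///
/// # Examples
///
/// A sketch, assuming this module's types are in scope: cap each evaluation at
/// 60 seconds and the whole run at one hour (values are illustrative):
///
/// ```ignore
/// let limits = ResourceConstraints {
///     max_time_per_evaluation: Some(60.0),
///     max_total_time: Some(3600.0),
///     ..ResourceConstraints::default()
/// };
/// ```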
#[derive(Debug, Clone, Default)]
pub struct ResourceConstraints {
    /// Maximum memory usage per evaluation (bytes)
    pub max_memory_per_evaluation: Option<usize>,
    /// Maximum time per evaluation (seconds)
    pub max_time_per_evaluation: Option<f64>,
    /// Maximum total tuning time (seconds)
    pub max_total_time: Option<f64>,
}

/// Hyperparameter specification
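///
/// # Examples
///
/// Because this type derives `Serialize` and `Deserialize`, specifications can
/// be round-tripped through JSON. A sketch, assuming `serde_json` is available
/// as a dependency:
///
/// ```ignore
/// let param = HyperParameter::LogUniform { min: 1e-4, max: 1e-1 };
/// let json = serde_json::to_string(&param).unwrap();
/// let restored: HyperParameter = serde_json::from_str(&json).unwrap();
/// ```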
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HyperParameter {
    /// Integer parameter with range [min, max]
    Integer { min: i64, max: i64 },
    /// Float parameter with range [min, max]
    Float { min: f64, max: f64 },
    /// Categorical parameter with choices
    Categorical { choices: Vec<String> },
    /// Boolean parameter
    Boolean,
    /// Log-uniform distribution for float parameters
    LogUniform { min: f64, max: f64 },
    /// Discrete choices for integer parameters
    IntegerChoices { choices: Vec<i64> },
}

/// Hyperparameter search space for clustering algorithms
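///
/// # Examples
///
/// A sketch of a search space for a K-means-style algorithm, assuming this
/// module's types are in scope (parameter names, ranges, and choices are
/// illustrative):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let mut parameters = HashMap::new();
/// parameters.insert(
///     "n_clusters".to_string(),
///     HyperParameter::Integer { min: 2, max: 20 },
/// );
/// parameters.insert(
///     "tolerance".to_string(),
///     HyperParameter::LogUniform { min: 1e-6, max: 1e-2 },
/// );
/// parameters.insert(
///     "init".to_string(),
///     HyperParameter::Categorical {
///         choices: vec!["k-means++".to_string(), "random".to_string()],
///     },
/// );
///
/// let space = SearchSpace {
///     parameters,
///     constraints: vec![ParameterConstraint::Range {
///         parameter: "n_clusters".to_string(),
///         min: 2.0,
///         max: 20.0,
///     }],
/// };
/// ```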
#[derive(Debug, Clone)]
pub struct SearchSpace {
    /// Parameters to optimize
    pub parameters: HashMap<String, HyperParameter>,
    /// Algorithm-specific constraints
    pub constraints: Vec<ParameterConstraint>,
}

/// Parameter constraints for interdependent hyperparameters
#[derive(Debug, Clone)]
pub enum ParameterConstraint {
    /// Conditional constraint: if condition then constraint
    Conditional {
        condition: String,
        constraint: Box<ParameterConstraint>,
    },
    /// Range constraint: parameter must be in range
    Range {
        parameter: String,
        min: f64,
        max: f64,
    },
    /// Dependency constraint: parameter A depends on parameter B
    Dependency {
        dependent: String,
        dependency: String,
        relationship: DependencyRelationship,
    },
}

/// Dependency relationships between parameters
#[derive(Debug, Clone)]
pub enum DependencyRelationship {
    /// Linear relationship: A = k * B + c
    Linear { k: f64, c: f64 },
    /// Proportional: A <= ratio * B
    Proportional { ratio: f64 },
    /// Custom function
    Custom(String),
}

/// Hyperparameter evaluation result
#[derive(Debug, Clone)]
pub struct EvaluationResult {
    /// Parameter values used
    pub parameters: HashMap<String, f64>,
    /// Primary metric score
    pub score: f64,
    /// Additional metrics
    pub additional_metrics: HashMap<String, f64>,
    /// Evaluation time (seconds)
    pub evaluation_time: f64,
    /// Memory usage (bytes)
    pub memory_usage: Option<usize>,
    /// Cross-validation scores
    pub cv_scores: Vec<f64>,
    /// Standard deviation of CV scores
    pub cv_std: f64,
    /// Algorithm-specific metadata
    pub metadata: HashMap<String, String>,
}

/// Results of a hyperparameter tuning run
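///
/// # Examples
///
/// A sketch of inspecting a result, where `result` is assumed to come from a
/// tuner elsewhere in the crate:
///
/// ```ignore
/// println!("best score: {:.4}", result.best_score);
/// for (name, value) in &result.best_parameters {
///     println!("  {name} = {value}");
/// }
/// if !result.convergence_info.converged {
///     println!("stopped early: {:?}", result.convergence_info.stopping_reason);
/// }
/// ```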
#[derive(Debug, Clone)]
pub struct TuningResult {
    /// Best parameter configuration found
    pub best_parameters: HashMap<String, f64>,
    /// Best score achieved
    pub best_score: f64,
    /// All evaluation results
    pub evaluation_history: Vec<EvaluationResult>,
    /// Convergence information
    pub convergence_info: ConvergenceInfo,
    /// Search space exploration statistics
    pub exploration_stats: ExplorationStats,
    /// Total tuning time
    pub total_time: f64,
    /// Ensemble results (if ensemble method was used)
    pub ensemble_results: Option<EnsembleResults>,
    /// Pareto front (for multi-objective optimization)
    pub pareto_front: Option<Vec<HashMap<String, f64>>>,
}

/// Results from ensemble tuning
#[derive(Debug, Clone)]
pub struct EnsembleResults {
    /// Results from each ensemble member
    pub member_results: Vec<TuningResult>,
    /// Consensus best parameters
    pub consensus_parameters: HashMap<String, f64>,
    /// Agreement score between ensemble members
    pub agreement_score: f64,
    /// Diversity metrics
    pub diversity_metrics: HashMap<String, f64>,
}

/// Bayesian optimization state
#[derive(Debug, Clone)]
pub struct BayesianState {
    /// Observed parameters and scores
    pub observations: Vec<(HashMap<String, f64>, f64)>,
    /// Gaussian process mean function
    pub gp_mean: Option<f64>,
    /// Gaussian process covariance matrix
    pub gp_covariance: Option<Array2<f64>>,
    /// Acquisition function values
    pub acquisition_values: Vec<f64>,
    /// Parameter names for consistent ordering
    pub parameter_names: Vec<String>,
    /// GP hyperparameters
    pub gp_hyperparameters: GpHyperparameters,
    /// Noise level
    pub noise_level: f64,
    /// Current best observed value
    pub current_best: f64,
}

/// Gaussian Process hyperparameters
#[derive(Debug, Clone)]
pub struct GpHyperparameters {
    /// Length scales for each dimension
    pub length_scales: Vec<f64>,
    /// Signal variance
    pub signal_variance: f64,
    /// Noise variance
    pub noise_variance: f64,
    /// Kernel type
    pub kernel_type: KernelType,
}

/// Convergence information
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Whether tuning converged
    pub converged: bool,
    /// Iteration at which convergence was detected
    pub convergence_iteration: Option<usize>,
    /// Reason for stopping
    pub stopping_reason: StoppingReason,
}

/// Reasons for stopping hyperparameter tuning
#[derive(Debug, Clone)]
pub enum StoppingReason {
    /// Maximum evaluations reached
    MaxEvaluations,
    /// Early stopping triggered
    EarlyStopping,
    /// Time limit exceeded
    TimeLimit,
    /// Convergence achieved
    Convergence,
    /// User interruption
    UserInterruption,
    /// Resource constraints exceeded
    ResourceConstraints,
}

/// Search space exploration statistics
#[derive(Debug, Clone)]
pub struct ExplorationStats {
    /// Parameter space coverage
    pub coverage: f64,
    /// Distribution of parameter values explored
    pub parameter_distributions: HashMap<String, Vec<f64>>,
    /// Correlation between parameters and performance
    pub parameter_importance: HashMap<String, f64>,
}