// trustformers_training/hyperopt/mod.rs
//! Automated Hyperparameter Tuning Framework
//!
//! This module provides a comprehensive hyperparameter optimization framework
//! for TrustformeRS models, supporting multiple search strategies and automated
//! experiment tracking.

// Self-contained automated tuners (random search, Bayesian optimization with a
// Gaussian-process surrogate) plus their config/result types.
pub mod auto_tuner;
// Efficiency machinery: bandit optimizers, parallel trial evaluation,
// early-stopping strategies, warm starting, and surrogate-based optimization.
pub mod efficiency;
// Ready-made objective functions and optimizer wrappers demonstrating usage.
pub mod examples;
// Parameter samplers (random, TPE, GP-based) that draw trial configurations.
pub mod sampler;
// Search-space definitions: continuous/discrete/categorical/log parameters.
pub mod search_space;
// Search strategies: grid/random search, Hyperband, successive halving, and
// population-based training.
pub mod strategies;
// Surrogate models (simple Gaussian process) and acquisition functions.
pub mod surrogate_models;
// Trial bookkeeping: state, metrics, results, and history.
pub mod trial;
// The core study runner (`HyperparameterTuner`) and its configuration.
pub mod tuner;
use serde::{Deserialize, Serialize};

// Flat re-exports so callers can reach every public type directly from
// `hyperopt::` without knowing the submodule layout. Several names exist in
// more than one submodule (`AcquisitionFunction`, `HyperparameterTuner`,
// `OptimizationDirection`, `ParameterValue`, `ResourceAllocation`); the
// `auto_tuner` versions are disambiguated with an `AutoTuner` prefix via
// `as` renames, while the other modules' versions keep their original names.
pub use auto_tuner::{
    AcquisitionFunction as AutoTunerAcquisitionFunction, AutomatedHyperparameterTuner,
    BayesianOptimizationTuner, GaussianProcess, HyperparameterConfig, HyperparameterSpace,
    HyperparameterTuner as AutoTunerHyperparameterTuner, Kernel,
    OptimizationDirection as AutoTunerOptimizationDirection, ParameterConstraint, ParameterScale,
    ParameterSpec, ParameterValue as AutoTunerParameterValue, RandomSearchTuner,
    ResourceAllocation as AutoTunerResourceAllocation, ResourceSharingStrategy, SearchAlgorithm,
    TuningConfig, TuningResult,
};
// Un-renamed `AcquisitionFunction`/`ResourceAllocation` come from `efficiency`.
pub use efficiency::{
    AcquisitionFunction, AcquisitionFunctionType, AdvancedEarlyStoppingConfig,
    ArmGenerationStrategy, ArmStatistics, BanditAlgorithm, BanditConfig, BanditOptimizer,
    EarlyStoppingStrategy, EvaluationJob, EvaluationResult, ExplorationStrategy,
    FaultToleranceConfig, GPUAllocation, JobStatus, KernelType, LoadBalancer,
    ParallelEvaluationConfig, ParallelEvaluator, ParallelStrategy, PriorityLevel,
    ResourceAllocation, ResourceUsage, RewardFunction, SurrogateConfig, SurrogateModel,
    SurrogateModelType, SurrogateOptimizer, WarmStartConfig, WarmStartDataSource,
    WarmStartStrategy,
};
pub use examples::{
    computer_vision_objective, language_modeling_objective, params_to_training_args,
    HyperparameterOptimizer, HyperparameterStudy, MultiStrategyOptimizer,
};
pub use sampler::{GPSampler, RandomSampler, Sampler, SamplerConfig, TPESampler};
// Un-renamed `ParameterValue` comes from `search_space`.
pub use search_space::{
    CategoricalParameter, ContinuousParameter, DiscreteParameter, HyperParameter, LogParameter,
    ParameterValue, SearchSpace,
};
pub use strategies::{
    BayesianOptimization, GridSearch, HalvingStrategy, Hyperband, PBTConfig, PBTMember, PBTStats,
    PopulationBasedTraining, RandomSearch, SearchStrategy, SuccessiveHalving,
};
pub use surrogate_models::{
    create_acquisition_function, create_surrogate_model, ExpectedImprovement,
    SimpleGaussianProcess, UpperConfidenceBound,
};
pub use trial::{Trial, TrialHistory, TrialMetrics, TrialResult, TrialState};
// Un-renamed `HyperparameterTuner`/`OptimizationDirection` come from `tuner`.
pub use tuner::{HyperparameterTuner, OptimizationDirection, StudyStatistics, TunerConfig};
57
58/// Direction for optimization (minimize or maximize the objective)
59#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
60pub enum Direction {
61    /// Minimize the objective value (e.g., loss)
62    Minimize,
63    /// Maximize the objective value (e.g., accuracy)
64    Maximize,
65}
66
/// Result of a hyperparameter optimization study
///
/// Aggregates everything a study produced: the winning trial, the full trial
/// list, success/failure counts, wall-clock duration, and summary statistics.
/// Unlike the config types in this module it carries no serde derives, so it
/// is not directly serializable.
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Best trial found
    // NOTE(review): "best" presumably follows the study's optimization
    // direction (min loss vs. max accuracy) — confirm against the tuner.
    pub best_trial: Trial,
    /// All trials run during the study
    pub trials: Vec<Trial>,
    /// Number of trials that completed successfully
    pub completed_trials: usize,
    /// Number of trials that failed
    // NOTE(review): completed_trials + failed_trials may be less than
    // trials.len() if pruned trials are counted separately — verify.
    pub failed_trials: usize,
    /// Total time spent on optimization
    pub total_duration: std::time::Duration,
    /// Statistics about the study
    pub statistics: StudyStatistics,
}
83
84/// Configuration for early stopping of trials
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct EarlyStoppingConfig {
87    /// Patience: number of evaluation steps to wait before stopping
88    pub patience: usize,
89    /// Minimum improvement threshold
90    pub min_delta: f64,
91    /// Whether to restore best weights when stopping
92    pub restore_best_weights: bool,
93}
94
/// Configuration for pruning unpromising trials
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
    /// Strategy to use for pruning
    pub strategy: PruningStrategy,
    /// Minimum number of steps before pruning can occur
    pub min_steps: usize,
    /// Percentile threshold for pruning (e.g., 0.5 = median)
    // NOTE(review): this overlaps with `PruningStrategy::Percentile(f64)`,
    // which also carries a threshold — confirm which one the pruner reads,
    // and consider consolidating to avoid the two values disagreeing.
    pub percentile: f64,
}
105
106/// Strategy for pruning trials
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub enum PruningStrategy {
109    /// No pruning
110    None,
111    /// Median pruning: stop if performance is below median
112    Median,
113    /// Percentile pruning: stop if performance is below specified percentile
114    Percentile(f64),
115    /// Successive halving: eliminate worst performing trials at each stage
116    SuccessiveHalving,
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// `Direction` variants compare equal to themselves and unequal to
    /// each other.
    #[test]
    fn test_direction() {
        let (lo, hi) = (Direction::Minimize, Direction::Maximize);
        assert_eq!(lo, Direction::Minimize);
        assert_ne!(lo, hi);
    }

    /// A `Percentile` strategy round-trips its threshold payload.
    #[test]
    fn test_pruning_strategy() {
        let strategy = PruningStrategy::Percentile(0.25);
        if let PruningStrategy::Percentile(threshold) = strategy {
            assert_eq!(threshold, 0.25);
        } else {
            panic!("Expected Percentile strategy");
        }
    }
}