// trustformers_training/hyperopt/mod.rs
pub mod auto_tuner;
8pub mod efficiency;
9pub mod examples;
10pub mod sampler;
11pub mod search_space;
12pub mod strategies;
13pub mod surrogate_models;
14pub mod trial;
15pub mod tuner;
16
17use serde::{Deserialize, Serialize};
18
19pub use auto_tuner::{
20 AcquisitionFunction as AutoTunerAcquisitionFunction, AutomatedHyperparameterTuner,
21 BayesianOptimizationTuner, GaussianProcess, HyperparameterConfig, HyperparameterSpace,
22 HyperparameterTuner as AutoTunerHyperparameterTuner, Kernel,
23 OptimizationDirection as AutoTunerOptimizationDirection, ParameterConstraint, ParameterScale,
24 ParameterSpec, ParameterValue as AutoTunerParameterValue, RandomSearchTuner,
25 ResourceAllocation as AutoTunerResourceAllocation, ResourceSharingStrategy, SearchAlgorithm,
26 TuningConfig, TuningResult,
27};
28pub use efficiency::{
29 AcquisitionFunction, AcquisitionFunctionType, AdvancedEarlyStoppingConfig,
30 ArmGenerationStrategy, ArmStatistics, BanditAlgorithm, BanditConfig, BanditOptimizer,
31 EarlyStoppingStrategy, EvaluationJob, EvaluationResult, ExplorationStrategy,
32 FaultToleranceConfig, GPUAllocation, JobStatus, KernelType, LoadBalancer,
33 ParallelEvaluationConfig, ParallelEvaluator, ParallelStrategy, PriorityLevel,
34 ResourceAllocation, ResourceUsage, RewardFunction, SurrogateConfig, SurrogateModel,
35 SurrogateModelType, SurrogateOptimizer, WarmStartConfig, WarmStartDataSource,
36 WarmStartStrategy,
37};
38pub use examples::{
39 computer_vision_objective, language_modeling_objective, params_to_training_args,
40 HyperparameterOptimizer, HyperparameterStudy, MultiStrategyOptimizer,
41};
42pub use sampler::{GPSampler, RandomSampler, Sampler, SamplerConfig, TPESampler};
43pub use search_space::{
44 CategoricalParameter, ContinuousParameter, DiscreteParameter, HyperParameter, LogParameter,
45 ParameterValue, SearchSpace,
46};
47pub use strategies::{
48 BayesianOptimization, GridSearch, HalvingStrategy, Hyperband, PBTConfig, PBTMember, PBTStats,
49 PopulationBasedTraining, RandomSearch, SearchStrategy, SuccessiveHalving,
50};
51pub use surrogate_models::{
52 create_acquisition_function, create_surrogate_model, ExpectedImprovement,
53 SimpleGaussianProcess, UpperConfidenceBound,
54};
55pub use trial::{Trial, TrialHistory, TrialMetrics, TrialResult, TrialState};
56pub use tuner::{HyperparameterTuner, OptimizationDirection, StudyStatistics, TunerConfig};
57
58#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
60pub enum Direction {
61 Minimize,
63 Maximize,
65}
66
/// Final outcome of a hyperparameter optimization run.
///
/// Aggregates the winning trial, the full trial list, completion counters,
/// total wall-clock time, and study-level statistics.
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// The trial that achieved the best objective value.
    pub best_trial: Trial,
    /// All trials executed during the study.
    // NOTE(review): ordering (execution order vs. score order) is not visible
    // here — confirm against the tuner that builds this struct.
    pub trials: Vec<Trial>,
    /// Number of trials that ran to completion.
    pub completed_trials: usize,
    /// Number of trials that terminated with an error.
    pub failed_trials: usize,
    /// Wall-clock duration of the whole optimization run.
    pub total_duration: std::time::Duration,
    /// Aggregate statistics over the study (see [`StudyStatistics`]).
    pub statistics: StudyStatistics,
}
83
/// Configuration for early stopping of a training run.
// NOTE(review): field semantics below follow the standard early-stopping
// convention implied by the names — confirm against the consuming trainer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
    /// Number of evaluations with no improvement tolerated before stopping.
    pub patience: usize,
    /// Minimum change in the monitored value that counts as an improvement.
    pub min_delta: f64,
    /// Whether to restore the best-scoring weights when stopping.
    pub restore_best_weights: bool,
}
94
/// Configuration for pruning (early termination) of unpromising trials.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
    /// Strategy used to decide whether a trial should be pruned.
    pub strategy: PruningStrategy,
    /// Minimum number of steps a trial must run before it may be pruned.
    pub min_steps: usize,
    /// Percentile threshold used by percentile-based pruning.
    // NOTE(review): `PruningStrategy::Percentile(f64)` also carries a value;
    // which of the two is authoritative is not visible here — confirm.
    pub percentile: f64,
}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub enum PruningStrategy {
109 None,
111 Median,
113 Percentile(f64),
115 SuccessiveHalving,
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// `Direction` equality behaves as derived: reflexive, and the two
    /// variants compare unequal.
    #[test]
    fn test_direction() {
        let dir = Direction::Minimize;
        assert_eq!(dir, Direction::Minimize);
        assert_ne!(dir, Direction::Maximize);
    }

    /// A `Percentile` strategy round-trips its payload.
    #[test]
    fn test_pruning_strategy() {
        let strategy = PruningStrategy::Percentile(0.25);
        assert!(
            matches!(strategy, PruningStrategy::Percentile(p) if p == 0.25),
            "Expected Percentile strategy"
        );
    }
}