1use serde::{Deserialize, Serialize};
2use std::path::PathBuf;
3use trustformers_core::errors::{invalid_config, Result};
4
/// Configuration for a training run, loosely mirroring the Hugging Face
/// `TrainingArguments` layout.
///
/// All fields are public; construct via [`TrainingArguments::new`] or
/// [`Default`] and override individual fields as needed. Call
/// [`TrainingArguments::validate`] before use.
///
/// NOTE: field order is significant for positional serde formats — do not
/// reorder fields without considering serialized compatibility.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingArguments {
    /// Directory where checkpoints and run outputs are written
    /// (default `./results`).
    pub output_dir: PathBuf,
    /// Whether an existing `output_dir` may be overwritten.
    /// NOTE(review): not enforced anywhere in this file — confirm the
    /// trainer honors it.
    pub overwrite_output_dir: bool,
    /// Whether to run evaluation (default `false`).
    pub do_eval: bool,
    /// Whether to run prediction (default `false`).
    pub do_predict: bool,
    /// Explicit number of warmup steps; when `> 0` it takes precedence over
    /// `warmup_ratio` (see `get_warmup_steps`).
    pub warmup_steps: usize,
    /// Fraction of total steps used for warmup when `warmup_steps == 0`.
    pub warmup_ratio: f32,
    /// Optimizer learning rate (default `5e-5`).
    pub learning_rate: f32,
    /// Weight-decay coefficient (default `0.0`).
    pub weight_decay: f32,
    /// Adam beta1 coefficient (default `0.9`).
    pub adam_beta1: f32,
    /// Adam beta2 coefficient (default `0.999`).
    pub adam_beta2: f32,
    /// Adam epsilon for numerical stability (default `1e-8`).
    pub adam_epsilon: f32,
    /// Gradient-clipping norm threshold (default `1.0`).
    pub max_grad_norm: f32,
    /// Number of training epochs; `f32` so fractional epochs are allowed
    /// (default `3.0`).
    pub num_train_epochs: f32,
    /// Hard cap on total steps; when `Some`, it overrides the epoch-derived
    /// count (see `get_total_steps`).
    pub max_steps: Option<usize>,
    /// Number of batches accumulated per optimizer update; multiplies the
    /// per-device batch size into the effective batch size.
    pub gradient_accumulation_steps: usize,
    /// Per-device training batch size (default `8`).
    pub per_device_train_batch_size: usize,
    /// Per-device evaluation batch size (default `8`).
    pub per_device_eval_batch_size: usize,
    /// Number of data-loading worker threads (default `0`).
    pub dataloader_num_workers: usize,
    /// Whether the data loader should pin host memory (default `false`).
    pub dataloader_pin_memory: bool,
    /// Checkpoint interval in steps (default `500`).
    pub save_steps: usize,
    /// Maximum number of checkpoints to retain; `None` keeps all.
    pub save_total_limit: Option<usize>,
    /// Logging interval in steps (default `10`).
    pub logging_steps: usize,
    /// Evaluation interval in steps (default `500`).
    pub eval_steps: usize,
    /// Whether to run a final evaluation when training ends (default `true`).
    pub eval_at_end: bool,
    /// RNG seed for reproducibility (default `42`).
    pub seed: u64,
    /// Enable fp16 mixed-precision training (default `false`).
    pub fp16: bool,
    /// Enable bf16 mixed-precision training (default `false`).
    pub bf16: bool,
    /// Metric name used to rank checkpoints when tracking the best model.
    pub metric_for_best_model: Option<String>,
    /// Whether a larger metric value is better; `None` leaves it unspecified.
    pub greater_is_better: Option<bool>,
    /// Evaluations without improvement before early stopping; `None`
    /// disables early stopping.
    pub early_stopping_patience: Option<usize>,
    /// Minimum metric improvement that resets early-stopping patience.
    pub early_stopping_threshold: Option<f32>,
    /// Whether to reload the best checkpoint when training finishes.
    pub load_best_model_at_end: bool,
    /// When evaluation runs during training (default
    /// [`EvaluationStrategy::No`]).
    pub evaluation_strategy: EvaluationStrategy,
    /// When checkpoints are saved (default [`SaveStrategy::Steps`]).
    pub save_strategy: SaveStrategy,
    /// Directory for log output; `None` uses the trainer's default.
    pub logging_dir: Option<PathBuf>,
    /// Whether to run training (default `true`).
    pub do_train: bool,
    /// Checkpoint path to resume from; `None` starts fresh.
    pub resume_from_checkpoint: Option<PathBuf>,
}
119
/// When evaluation is run during training.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum EvaluationStrategy {
    /// Never evaluate during training.
    No,
    /// Evaluate on a fixed step interval (see
    /// `TrainingArguments::eval_steps`).
    Steps,
    /// Evaluate at the end of each epoch.
    Epoch,
}
129
/// When checkpoints are saved during training.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SaveStrategy {
    /// Never save checkpoints during training.
    No,
    /// Save on a fixed step interval (see `TrainingArguments::save_steps`).
    Steps,
    /// Save at the end of each epoch.
    Epoch,
}
139
140impl Default for TrainingArguments {
141 fn default() -> Self {
142 Self {
143 output_dir: PathBuf::from("./results"),
144 overwrite_output_dir: false,
145 do_eval: false,
146 do_predict: false,
147 warmup_steps: 0,
148 warmup_ratio: 0.0,
149 learning_rate: 5e-5,
150 weight_decay: 0.0,
151 adam_beta1: 0.9,
152 adam_beta2: 0.999,
153 adam_epsilon: 1e-8,
154 max_grad_norm: 1.0,
155 num_train_epochs: 3.0,
156 max_steps: None,
157 gradient_accumulation_steps: 1,
158 per_device_train_batch_size: 8,
159 per_device_eval_batch_size: 8,
160 dataloader_num_workers: 0,
161 dataloader_pin_memory: false,
162 save_steps: 500,
163 save_total_limit: None,
164 logging_steps: 10,
165 eval_steps: 500,
166 eval_at_end: true,
167 seed: 42,
168 fp16: false,
169 bf16: false,
170 metric_for_best_model: None,
171 greater_is_better: None,
172 early_stopping_patience: None,
173 early_stopping_threshold: None,
174 load_best_model_at_end: false,
175 evaluation_strategy: EvaluationStrategy::No,
176 save_strategy: SaveStrategy::Steps,
177 logging_dir: None,
178 do_train: true,
179 resume_from_checkpoint: None,
180 }
181 }
182}
183
184impl TrainingArguments {
185 pub fn new(output_dir: impl Into<PathBuf>) -> Self {
187 Self {
188 output_dir: output_dir.into(),
189 ..Default::default()
190 }
191 }
192
193 pub fn get_total_steps(&self, num_examples: usize) -> usize {
195 if let Some(max_steps) = self.max_steps {
196 max_steps
197 } else {
198 let steps_per_epoch = num_examples.div_ceil(self.per_device_train_batch_size);
199 (self.num_train_epochs * steps_per_epoch as f32) as usize
200 }
201 }
202
203 pub fn get_effective_batch_size(&self) -> usize {
205 self.per_device_train_batch_size * self.gradient_accumulation_steps
206 }
207
208 pub fn get_warmup_steps(&self, total_steps: usize) -> usize {
210 if self.warmup_steps > 0 {
211 self.warmup_steps
212 } else {
213 (self.warmup_ratio * total_steps as f32) as usize
214 }
215 }
216
217 pub fn validate(&self) -> Result<()> {
219 if self.per_device_train_batch_size == 0 {
220 return Err(invalid_config(
221 "per_device_train_batch_size",
222 "must be greater than 0",
223 ));
224 }
225
226 if self.per_device_eval_batch_size == 0 {
227 return Err(invalid_config(
228 "per_device_eval_batch_size",
229 "must be greater than 0",
230 ));
231 }
232
233 if self.gradient_accumulation_steps == 0 {
234 return Err(invalid_config(
235 "gradient_accumulation_steps",
236 "must be greater than 0",
237 ));
238 }
239
240 if self.learning_rate <= 0.0 {
241 return Err(invalid_config("learning_rate", "must be positive"));
242 }
243
244 if self.num_train_epochs <= 0.0 && self.max_steps.is_none() {
245 return Err(invalid_config(
246 "training_schedule",
247 "either num_train_epochs or max_steps must be positive",
248 ));
249 }
250
251 Ok(())
252 }
253}