//! `trustformers_training/training_args.rs`
//!
//! Configuration arguments for training, closely matching HuggingFace's
//! `TrainingArguments`.
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use trustformers_core::errors::{invalid_config, Result};
5/// Configuration arguments for training, closely matching HuggingFace's TrainingArguments
/// Configuration arguments for training, closely matching HuggingFace's TrainingArguments.
///
/// Construct via [`TrainingArguments::new`] or `Default`, then call
/// `validate()` before starting a run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingArguments {
    /// The output directory where the model predictions and checkpoints will be written.
    pub output_dir: PathBuf,

    /// Whether to overwrite the content of the output directory.
    pub overwrite_output_dir: bool,

    /// Whether to do evaluation during training.
    pub do_eval: bool,

    /// Whether to do prediction on the test set.
    pub do_predict: bool,

    /// Number of steps used for a linear warmup from 0 to learning_rate.
    /// When greater than 0, this takes precedence over `warmup_ratio`
    /// (see `get_warmup_steps`).
    pub warmup_steps: usize,

    /// Ratio of total training steps used for a linear warmup from 0 to
    /// learning_rate. Only consulted when `warmup_steps` is 0.
    pub warmup_ratio: f32,

    /// Learning rate for the optimizer.
    pub learning_rate: f32,

    /// Weight decay coefficient for regularization.
    pub weight_decay: f32,

    /// Beta1 hyperparameter for the Adam optimizer.
    pub adam_beta1: f32,

    /// Beta2 hyperparameter for the Adam optimizer.
    pub adam_beta2: f32,

    /// Epsilon hyperparameter for the Adam optimizer.
    pub adam_epsilon: f32,

    /// Maximum gradient norm for gradient clipping.
    pub max_grad_norm: f32,

    /// Total number of training epochs to perform. Fractional values are
    /// allowed (the step count is truncated in `get_total_steps`).
    pub num_train_epochs: f32,

    /// Total number of training steps to perform (overrides num_train_epochs if set).
    pub max_steps: Option<usize>,

    /// Number of updates steps to accumulate before performing a backward/update pass.
    pub gradient_accumulation_steps: usize,

    /// Batch size per device during training.
    pub per_device_train_batch_size: usize,

    /// Batch size per device during evaluation.
    pub per_device_eval_batch_size: usize,

    /// Number of subprocesses to use for data loading.
    pub dataloader_num_workers: usize,

    /// Whether to pin memory in data loaders.
    pub dataloader_pin_memory: bool,

    /// How often (in steps) to save the model checkpoint.
    pub save_steps: usize,

    /// Maximum number of checkpoints to keep; `None` keeps all.
    pub save_total_limit: Option<usize>,

    /// How often (in steps) to log training metrics.
    pub logging_steps: usize,

    /// How often (in steps) to run evaluation.
    pub eval_steps: usize,

    /// Whether to run evaluation at the end of training.
    pub eval_at_end: bool,

    /// Random seed for initialization.
    pub seed: u64,

    /// Whether to use 16-bit (float16) mixed precision training.
    /// NOTE(review): presumably mutually exclusive with `bf16` — the trainer
    /// should reject configs with both enabled.
    pub fp16: bool,

    /// Whether to use bfloat16 mixed precision training.
    pub bf16: bool,

    /// The name of the metric to use to compare two different models.
    pub metric_for_best_model: Option<String>,

    /// Whether the metric_for_best_model should be maximized or not.
    pub greater_is_better: Option<bool>,

    /// How many evaluation calls to wait before stopping training.
    pub early_stopping_patience: Option<usize>,

    /// Minimum change in the monitored metric to qualify as an improvement.
    pub early_stopping_threshold: Option<f32>,

    /// Whether to load the best model found during training at the end of training.
    pub load_best_model_at_end: bool,

    /// Strategy to adopt during evaluation.
    pub evaluation_strategy: EvaluationStrategy,

    /// Strategy to adopt for saving checkpoints.
    pub save_strategy: SaveStrategy,

    /// The logging directory to use; `None` leaves the choice to the trainer.
    pub logging_dir: Option<PathBuf>,

    /// Whether to run training.
    pub do_train: bool,

    /// Resume training from a checkpoint at this path, if set.
    pub resume_from_checkpoint: Option<PathBuf>,
}
119
120#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
121pub enum EvaluationStrategy {
122    /// No evaluation during training
123    No,
124    /// Evaluate every eval_steps
125    Steps,
126    /// Evaluate at the end of each epoch
127    Epoch,
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
131pub enum SaveStrategy {
132    /// No saving during training
133    No,
134    /// Save every save_steps
135    Steps,
136    /// Save at the end of each epoch
137    Epoch,
138}
139
140impl Default for TrainingArguments {
141    fn default() -> Self {
142        Self {
143            output_dir: PathBuf::from("./results"),
144            overwrite_output_dir: false,
145            do_eval: false,
146            do_predict: false,
147            warmup_steps: 0,
148            warmup_ratio: 0.0,
149            learning_rate: 5e-5,
150            weight_decay: 0.0,
151            adam_beta1: 0.9,
152            adam_beta2: 0.999,
153            adam_epsilon: 1e-8,
154            max_grad_norm: 1.0,
155            num_train_epochs: 3.0,
156            max_steps: None,
157            gradient_accumulation_steps: 1,
158            per_device_train_batch_size: 8,
159            per_device_eval_batch_size: 8,
160            dataloader_num_workers: 0,
161            dataloader_pin_memory: false,
162            save_steps: 500,
163            save_total_limit: None,
164            logging_steps: 10,
165            eval_steps: 500,
166            eval_at_end: true,
167            seed: 42,
168            fp16: false,
169            bf16: false,
170            metric_for_best_model: None,
171            greater_is_better: None,
172            early_stopping_patience: None,
173            early_stopping_threshold: None,
174            load_best_model_at_end: false,
175            evaluation_strategy: EvaluationStrategy::No,
176            save_strategy: SaveStrategy::Steps,
177            logging_dir: None,
178            do_train: true,
179            resume_from_checkpoint: None,
180        }
181    }
182}
183
184impl TrainingArguments {
185    /// Create a new TrainingArguments with the specified output directory
186    pub fn new(output_dir: impl Into<PathBuf>) -> Self {
187        Self {
188            output_dir: output_dir.into(),
189            ..Default::default()
190        }
191    }
192
193    /// Calculate the total number of training steps
194    pub fn get_total_steps(&self, num_examples: usize) -> usize {
195        if let Some(max_steps) = self.max_steps {
196            max_steps
197        } else {
198            let steps_per_epoch = num_examples.div_ceil(self.per_device_train_batch_size);
199            (self.num_train_epochs * steps_per_epoch as f32) as usize
200        }
201    }
202
203    /// Calculate the effective batch size (accounting for gradient accumulation)
204    pub fn get_effective_batch_size(&self) -> usize {
205        self.per_device_train_batch_size * self.gradient_accumulation_steps
206    }
207
208    /// Calculate the number of warmup steps
209    pub fn get_warmup_steps(&self, total_steps: usize) -> usize {
210        if self.warmup_steps > 0 {
211            self.warmup_steps
212        } else {
213            (self.warmup_ratio * total_steps as f32) as usize
214        }
215    }
216
217    /// Validate the training arguments
218    pub fn validate(&self) -> Result<()> {
219        if self.per_device_train_batch_size == 0 {
220            return Err(invalid_config(
221                "per_device_train_batch_size",
222                "must be greater than 0",
223            ));
224        }
225
226        if self.per_device_eval_batch_size == 0 {
227            return Err(invalid_config(
228                "per_device_eval_batch_size",
229                "must be greater than 0",
230            ));
231        }
232
233        if self.gradient_accumulation_steps == 0 {
234            return Err(invalid_config(
235                "gradient_accumulation_steps",
236                "must be greater than 0",
237            ));
238        }
239
240        if self.learning_rate <= 0.0 {
241            return Err(invalid_config("learning_rate", "must be positive"));
242        }
243
244        if self.num_train_epochs <= 0.0 && self.max_steps.is_none() {
245            return Err(invalid_config(
246                "training_schedule",
247                "either num_train_epochs or max_steps must be positive",
248            ));
249        }
250
251        Ok(())
252    }
253}