tiny_recursive_rs/training/trainer.rs
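
//! Training loop for the Tiny Recursive Model (TRM): builds the model and an
//! AdamW optimizer, steps a warmup + cosine learning-rate schedule, and writes
//! safetensors checkpoints with JSON metadata alongside them.
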
use candle_core::{Result, Tensor, Device, DType};
use candle_nn::{VarMap, VarBuilder, AdamW, ParamsAdamW, Optimizer, loss, ops};
use std::path::Path;

use crate::{TinyRecursiveModel, TRMConfig};
use crate::data::{NumpyDataLoader, BatchDataLoader};
use crate::models::InnerCarry;
use super::scheduler::CosineScheduler;
use super::ema::{EMA, EMAConfig};
use super::checkpoint::{Checkpoint, CheckpointMetadata};

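/// Hyperparameters for the training loop. `warmup_steps` and `total_steps` drive
/// the cosine learning-rate schedule; `save_every` and `eval_every` are measured
/// in optimizer steps.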
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    pub num_epochs: usize,
    pub batch_size: usize,
    pub learning_rate: f64,
    pub lr_min: f64,
    pub warmup_steps: usize,
    pub total_steps: usize,
    pub weight_decay: f64,
    pub grad_clip: Option<f64>,
    pub ema_decay: f64,
    pub save_every: usize,
    pub eval_every: usize,
    pub checkpoint_dir: String,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            num_epochs: 10,
            batch_size: 32,
            learning_rate: 3e-4,
            lr_min: 3e-5,
            warmup_steps: 1000,
            total_steps: 100_000,
            weight_decay: 0.1,
            grad_clip: Some(1.0),
            ema_decay: 0.9999,
            save_every: 1000,
            eval_every: 500,
            checkpoint_dir: "checkpoints".to_string(),
        }
    }
}

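/// Owns the model weights (via its `VarMap`), the AdamW optimizer, the cosine
/// learning-rate scheduler, and training state such as the global step counter.
///
/// Minimal usage sketch; `model_cfg` and `loader` are placeholders, not part of
/// this file:
///
/// ```ignore
/// let device = Device::Cpu;
/// let mut trainer = Trainer::new(model_cfg, TrainingConfig::default(), device)?;
/// trainer.train(&mut loader)?; // `loader` is any type implementing BatchDataLoader
/// ```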
pub struct Trainer {
    model: TinyRecursiveModel,
    model_config: TRMConfig,
    varmap: VarMap,
    optimizer: AdamW,
    scheduler: CosineScheduler,
    ema: Option<EMA>,
    config: TrainingConfig,
    device: Device,
    step: usize,
}

impl Trainer {
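    /// Builds the model (f16 weights on CUDA, f32 on CPU), an AdamW optimizer over
    /// every variable in the `VarMap`, and the warmup + cosine learning-rate scheduler.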
    pub fn new(
        model_config: TRMConfig,
        training_config: TrainingConfig,
        device: Device,
    ) -> Result<Self> {
        // Use f16 weights on CUDA and f32 on CPU.
        let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, &device);
        let model = TinyRecursiveModel::new(model_config.clone(), vb)
            .map_err(|e| candle_core::Error::Msg(format!("Model init failed: {:?}", e)))?;

        let optimizer_params = ParamsAdamW {
            lr: training_config.learning_rate,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: training_config.weight_decay,
        };
        let optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;

        let scheduler = CosineScheduler::new(super::scheduler::CosineSchedulerConfig {
            lr_init: training_config.learning_rate,
            lr_min: training_config.lr_min,
            warmup_steps: training_config.warmup_steps,
            total_steps: training_config.total_steps,
        });

        // EMA weight tracking is not wired up yet, so `ema_decay` is currently unused.
        let ema = None;

        Ok(Self {
            model,
            model_config,
            varmap,
            optimizer,
            scheduler,
            ema,
            config: training_config,
            device,
            step: 0,
        })
    }

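    /// Cross-entropy loss. Targets shaped `[batch, 1]` are treated as one label per
    /// example (logits are mean-pooled over the sequence dimension); any other shape
    /// is treated as per-token labels. The result is a scalar tensor that stays on
    /// the autograd graph so `backward_step` can propagate gradients through it.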
    fn compute_loss(
        &self,
        logits: &Tensor,
        targets: &Tensor,
    ) -> Result<Tensor> {
        let batch_size = logits.dim(0)?;
        let seq_len = logits.dim(1)?;
        let num_classes = logits.dim(2)?;

        let target_shape = targets.dims();

        if target_shape.len() == 2 && target_shape[1] == 1 {
            // One label per example: pool the logits over the sequence dimension and
            // compute the loss in f32 for numerical stability when the model runs in f16.
            let logits_pooled = logits.mean(1)?.to_dtype(DType::F32)?;
            let targets_flat = targets.flatten_all()?;
            let log_probs = ops::log_softmax(&logits_pooled, candle_core::D::Minus1)?;
            // Mean negative log-likelihood over the batch; the result keeps its
            // computation graph so gradients can flow back through the model.
            loss::nll(&log_probs, &targets_flat)
        } else {
            // One label per token: flatten to (batch * seq, classes) and average the
            // negative log-likelihood over all positions.
            let logits_flat = logits
                .reshape((batch_size * seq_len, num_classes))?
                .to_dtype(DType::F32)?;
            let targets_flat = targets.flatten_all()?;
            let log_probs = ops::log_softmax(&logits_flat, candle_core::D::Minus1)?;
            loss::nll(&log_probs, &targets_flat)
        }
    }

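    /// Runs a single optimization step: forward pass from a fresh carry, loss,
    /// backward pass and AdamW update, then advances the learning-rate schedule
    /// and the global step counter. Returns the scalar loss value.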
    pub fn train_step(
        &mut self,
        input_ids: &Tensor,
        target_ids: &Tensor,
    ) -> Result<f32> {
        let batch_size = input_ids.dim(0)?;
        let seq_len = input_ids.dim(1)?;

        log::debug!("Input dtype: {:?}, Target dtype: {:?}", input_ids.dtype(), target_ids.dtype());

        // Start each step from an empty recursive carry matching the model dtype.
        let dtype = if self.device.is_cuda() { DType::F16 } else { DType::F32 };
        let carry = InnerCarry::empty(
            batch_size,
            seq_len,
            self.model_config.hidden_size,
            dtype,
            &self.device,
        )?;

        log::debug!("Running forward pass...");
        let (_new_carry, logits) = self.model.forward(&carry, input_ids)
            .map_err(|e| candle_core::Error::Msg(format!("Forward pass failed: {:?}", e)))?;

        log::debug!("Logits shape: {:?}, dtype: {:?}", logits.dims(), logits.dtype());

        log::debug!("Computing loss...");
        let loss = self.compute_loss(&logits, target_ids)
            .map_err(|e| candle_core::Error::Msg(format!("Loss computation failed: {:?}", e)))?;
        let loss_val = loss.to_scalar::<f32>()?;

        // Apply the scheduled learning rate before the optimizer update.
        let lr = self.scheduler.get_lr();
        self.optimizer.set_learning_rate(lr);

        // NOTE: `config.grad_clip` is not applied here; gradients go straight into
        // the AdamW update.
        self.optimizer.backward_step(&loss)?;

        self.scheduler.step();

        self.step += 1;

        Ok(loss_val)
    }

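    /// Saves the model weights as safetensors at `path` and writes a sibling
    /// `<path>.meta.json` file containing the current step, learning rate, and loss.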
    pub fn save_checkpoint<P: AsRef<Path>>(&self, path: P, loss: Option<f64>) -> Result<()> {
        std::fs::create_dir_all(&self.config.checkpoint_dir)
            .map_err(|e| candle_core::Error::Msg(format!("Failed to create checkpoint dir: {}", e)))?;

        self.varmap.save(path.as_ref())?;

        let metadata = CheckpointMetadata {
            step: self.step,
            lr: self.scheduler.get_lr(),
            loss,
            config: None,
        };

        let metadata_path = format!("{}.meta.json", path.as_ref().display());
        let metadata_json = serde_json::to_string_pretty(&metadata)
            .map_err(|e| candle_core::Error::Msg(format!("Metadata serialization failed: {}", e)))?;
        std::fs::write(&metadata_path, metadata_json)
            .map_err(|e| candle_core::Error::Msg(format!("Metadata write failed: {}", e)))?;

        log::debug!("Saved checkpoint weights to {} and metadata to {}", path.as_ref().display(), metadata_path);

        Ok(())
    }

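    /// Trains over one full pass of the data loader, logging every 100 steps and
    /// checkpointing every `save_every` steps. Returns the average batch loss.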
    pub fn train_epoch(&mut self, dataloader: &mut impl BatchDataLoader) -> Result<f32> {
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        dataloader.reset();

        while let Some((input_ids, target_ids)) = dataloader.next_batch(&self.device)? {
            let loss = self.train_step(&input_ids, &target_ids)?;
            total_loss += loss;
            num_batches += 1;

            if self.step % 100 == 0 {
                log::info!(
                    "Step {}: loss={:.4}, lr={:.6}",
                    self.step,
                    loss,
                    self.scheduler.get_lr()
                );
            }

            if self.step % self.config.save_every == 0 {
                let checkpoint_path = format!(
                    "{}/checkpoint_step_{}.safetensors",
                    self.config.checkpoint_dir,
                    self.step
                );
                log::info!("Saving checkpoint to {}", checkpoint_path);
                self.save_checkpoint(&checkpoint_path, Some(loss as f64))?;
            }
        }

        let avg_loss = total_loss / num_batches as f32;
        Ok(avg_loss)
    }

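    /// Runs the full training loop: `num_epochs` passes over the data with a
    /// checkpoint after each epoch, then saves the final weights to
    /// `final_model.safetensors` in the checkpoint directory.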
    pub fn train(&mut self, dataloader: &mut impl BatchDataLoader) -> Result<()> {
        log::info!("Starting training for {} epochs", self.config.num_epochs);
        log::info!("Total batches per epoch: {}", dataloader.num_batches());

        for epoch in 0..self.config.num_epochs {
            log::info!("=== Epoch {}/{} ===", epoch + 1, self.config.num_epochs);

            let avg_loss = self.train_epoch(dataloader)?;

            log::info!(
                "Epoch {} complete: avg_loss={:.4}, step={}",
                epoch + 1,
                avg_loss,
                self.step
            );

            let checkpoint_path = format!(
                "{}/checkpoint_epoch_{}.safetensors",
                self.config.checkpoint_dir,
                epoch + 1
            );
            self.save_checkpoint(&checkpoint_path, Some(avg_loss as f64))?;
        }

        log::info!("Training complete!");

        let final_path = format!("{}/final_model.safetensors", self.config.checkpoint_dir);
        log::info!("Saving final model to {}", final_path);
        self.varmap.save(&final_path)?;

        Ok(())
    }
}