tiny_recursive_rs/training/trainer.rs

/// Training loop for TinyRecursiveModel
use candle_core::{Result, Tensor, Device, DType};
use candle_nn::{VarMap, VarBuilder, AdamW, ParamsAdamW, Optimizer, loss};
use std::path::Path;

use crate::{TinyRecursiveModel, TRMConfig};
use crate::data::BatchDataLoader;
use crate::models::InnerCarry;
use super::scheduler::CosineScheduler;
use super::ema::EMA;
use super::checkpoint::CheckpointMetadata;

/// Training configuration
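///
/// # Example (illustrative sketch)
///
/// Override a few fields and keep the defaults defined below:
///
/// ```ignore
/// let cfg = TrainingConfig {
///     num_epochs: 3,
///     batch_size: 64,
///     ..Default::default()
/// };
/// ```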
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of training epochs
    pub num_epochs: usize,
    /// Batch size
    pub batch_size: usize,
    /// Learning rate (initial)
    pub learning_rate: f64,
    /// Minimum learning rate
    pub lr_min: f64,
    /// Warmup steps
    pub warmup_steps: usize,
    /// Total training steps (for scheduler)
    pub total_steps: usize,
    /// Weight decay
    pub weight_decay: f64,
    /// Gradient clipping value
    pub grad_clip: Option<f64>,
    /// EMA decay
    pub ema_decay: f64,
    /// Save checkpoint every N steps
    pub save_every: usize,
    /// Evaluation every N steps
    pub eval_every: usize,
    /// Checkpoint directory
    pub checkpoint_dir: String,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            num_epochs: 10,
            batch_size: 32,
            learning_rate: 3e-4,
            lr_min: 3e-5,
            warmup_steps: 1000,
            total_steps: 100000,
            weight_decay: 0.1,
            grad_clip: Some(1.0),
            ema_decay: 0.9999,
            save_every: 1000,
            eval_every: 500,
            checkpoint_dir: "checkpoints".to_string(),
        }
    }
}

/// Trainer for TinyRecursiveModel
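///
/// # Example (illustrative sketch)
///
/// Assumes `TRMConfig` implements `Default` and that `loader` is some
/// `BatchDataLoader` implementation; adjust to the actual crate API.
///
/// ```ignore
/// let device = Device::Cpu;
/// let mut trainer = Trainer::new(TRMConfig::default(), TrainingConfig::default(), device)?;
/// trainer.train(&mut loader)?;
/// ```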
pub struct Trainer {
    model: TinyRecursiveModel,
    model_config: TRMConfig,
    varmap: VarMap,
    optimizer: AdamW,
    scheduler: CosineScheduler,
    ema: Option<EMA>,
    config: TrainingConfig,
    device: Device,
    step: usize,
}

impl Trainer {
    /// Create new trainer
    pub fn new(
        model_config: TRMConfig,
        training_config: TrainingConfig,
        device: Device,
    ) -> Result<Self> {
        // Create model weights in F16 on CUDA (fast half-precision paths) and F32 on CPU
        let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, &device);
        let model = TinyRecursiveModel::new(model_config.clone(), vb)
            .map_err(|e| candle_core::Error::Msg(format!("Model init failed: {:?}", e)))?;

        // Create optimizer using candle's built-in AdamW
        let optimizer_params = ParamsAdamW {
            lr: training_config.learning_rate,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: training_config.weight_decay,
        };
        let optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;

        // Create scheduler
        let scheduler = CosineScheduler::new(super::scheduler::CosineSchedulerConfig {
            lr_init: training_config.learning_rate,
            lr_min: training_config.lr_min,
            warmup_steps: training_config.warmup_steps,
            total_steps: training_config.total_steps,
        });

        // EMA disabled for training speed
        let ema = None;

        Ok(Self {
            model,
            model_config,
            varmap,
            optimizer,
            scheduler,
            ema,
            config: training_config,
            device,
            step: 0,
        })
    }

    /// Compute loss for a batch
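    ///
    /// Two target layouts are supported:
    ///
    /// ```text
    /// classification:     logits [batch, seq_len, num_classes], targets [batch, 1]
    /// sequence modeling:  logits [batch, seq_len, num_classes], targets [batch, seq_len]
    /// ```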
    fn compute_loss(
        &self,
        logits: &Tensor,
        targets: &Tensor,
    ) -> Result<Tensor> {
        // For opcode classification:
        // logits shape: [batch, seq_len, num_classes]
        // targets shape: [batch, 1] (classification) or [batch, seq_len] (sequence modeling)
        //
        // The loss must stay attached to the autograd graph so that
        // backward_step() can produce gradients for the model parameters.
        // Rebuilding the loss from extracted scalars would detach it, so we use
        // candle's `loss::cross_entropy` (log_softmax + NLL averaged over the
        // batch), which is the same mean negative log likelihood.

        let batch_size = logits.dim(0)?;
        let seq_len = logits.dim(1)?;
        let num_classes = logits.dim(2)?;

        let target_shape = targets.dims();

        // Check if this is classification (single label per example) or sequence modeling
        if target_shape.len() == 2 && target_shape[1] == 1 {
            // Classification task: targets shape [batch, 1]
            // Pool logits across the sequence (mean pooling) -> [batch, num_classes]
            // and convert to F32 for a numerically stable loss.
            let logits_pooled = logits.mean(1)?.to_dtype(DType::F32)?;

            // Flatten targets to [batch]
            let targets_flat = targets.flatten_all()?;

            loss::cross_entropy(&logits_pooled, &targets_flat)
        } else {
            // Sequence modeling task: targets shape [batch, seq_len]
            let logits_flat = logits
                .reshape((batch_size * seq_len, num_classes))?
                .to_dtype(DType::F32)?;
            let targets_flat = targets.flatten_all()?;

            loss::cross_entropy(&logits_flat, &targets_flat)
        }
    }

    /// Training step
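    ///
    /// Runs one forward pass, computes the loss, and applies a single AdamW
    /// update; returns the scalar loss value.
    ///
    /// # Example (illustrative sketch)
    ///
    /// Assumes `u32` token ids and labels, e.g. a classification batch of
    /// 8 sequences of length 128 (names and shapes are placeholders):
    ///
    /// ```ignore
    /// let input_ids = Tensor::zeros((8, 128), DType::U32, &device)?;
    /// let target_ids = Tensor::zeros((8, 1), DType::U32, &device)?;
    /// let loss = trainer.train_step(&input_ids, &target_ids)?;
    /// ```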
    pub fn train_step(
        &mut self,
        input_ids: &Tensor,
        target_ids: &Tensor,
    ) -> Result<f32> {
        // Get batch size and sequence length
        let batch_size = input_ids.dim(0)?;
        let seq_len = input_ids.dim(1)?;

        log::debug!("Input dtype: {:?}, Target dtype: {:?}", input_ids.dtype(), target_ids.dtype());

        // Create initial carry with correct dtype (F16 on GPU, F32 on CPU)
        let dtype = if self.device.is_cuda() { DType::F16 } else { DType::F32 };
        let carry = InnerCarry::empty(
            batch_size,
            seq_len,
            self.model_config.hidden_size,
            dtype,
            &self.device,
        )?;

        // Forward pass
        log::debug!("Running forward pass...");
        let (_new_carry, logits) = self.model.forward(&carry, input_ids)
            .map_err(|e| candle_core::Error::Msg(format!("Forward pass failed: {:?}", e)))?;

        log::debug!("Logits shape: {:?}, dtype: {:?}", logits.dims(), logits.dtype());

        // Compute loss
        log::debug!("Computing loss...");
        let loss = self.compute_loss(&logits, target_ids)
            .map_err(|e| candle_core::Error::Msg(format!("Loss computation failed: {:?}", e)))?;
        let loss_val = loss.to_scalar::<f32>()?;

        // Update learning rate before optimizer step
        let lr = self.scheduler.get_lr();
        self.optimizer.set_learning_rate(lr);

        // Backward pass + parameter update (all in one!)
        // This is THE KEY: backward_step() computes gradients AND updates parameters in-place
        self.optimizer.backward_step(&loss)?;

        // Scheduler step
        self.scheduler.step();

        // EMA disabled for speed

        self.step += 1;

        Ok(loss_val)
    }

    /// Save checkpoint
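    ///
    /// # Example (illustrative sketch; the path and loss value are placeholders)
    ///
    /// ```ignore
    /// trainer.save_checkpoint("checkpoints/manual.safetensors", Some(0.42))?;
    /// // writes the weights plus `checkpoints/manual.safetensors.meta.json`
    /// ```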
    pub fn save_checkpoint<P: AsRef<Path>>(&self, path: P, loss: Option<f64>) -> Result<()> {
        // Make sure the directory the checkpoint is written to exists
        if let Some(parent) = path.as_ref().parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| candle_core::Error::Msg(format!("Failed to create checkpoint dir: {}", e)))?;
        }

        // Save weights using varmap.save() - this actually saves the tensors
        self.varmap.save(path.as_ref())?;

        // Save metadata separately as JSON sidecar
        let metadata = CheckpointMetadata {
            step: self.step,
            lr: self.scheduler.get_lr(),
            loss,
            config: None,
        };

        let metadata_path = format!("{}.meta.json", path.as_ref().display());
        let metadata_json = serde_json::to_string_pretty(&metadata)
            .map_err(|e| candle_core::Error::Msg(format!("Metadata serialization failed: {}", e)))?;
        std::fs::write(&metadata_path, metadata_json)
            .map_err(|e| candle_core::Error::Msg(format!("Metadata write failed: {}", e)))?;

        log::debug!("Saved checkpoint weights to {} and metadata to {}", path.as_ref().display(), metadata_path);

        Ok(())
    }

    /// Train for one epoch
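    ///
    /// Returns the average loss over all batches in the epoch; a checkpoint is
    /// written every `save_every` steps along the way.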
    pub fn train_epoch(&mut self, dataloader: &mut impl BatchDataLoader) -> Result<f32> {
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        dataloader.reset();

        while let Some((input_ids, target_ids)) = dataloader.next_batch(&self.device)? {
            let loss = self.train_step(&input_ids, &target_ids)?;
            total_loss += loss;
            num_batches += 1;

            // Log every 100 steps rather than every batch to keep logging cheap
            if self.step % 100 == 0 {
                log::info!(
                    "Step {}: loss={:.4}, lr={:.6}",
                    self.step,
                    loss,
                    self.scheduler.get_lr()
                );
            }

            // Save checkpoint
            if self.step % self.config.save_every == 0 {
                let checkpoint_path = format!(
                    "{}/checkpoint_step_{}.safetensors",
                    self.config.checkpoint_dir,
                    self.step
                );
                log::info!("Saving checkpoint to {}", checkpoint_path);
                self.save_checkpoint(&checkpoint_path, Some(loss as f64))?;
            }
        }

        let avg_loss = total_loss / num_batches as f32;
        Ok(avg_loss)
    }

    /// Full training loop
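    ///
    /// # Example (illustrative sketch)
    ///
    /// Assumes `loader` is some `BatchDataLoader` implementation:
    ///
    /// ```ignore
    /// trainer.train(&mut loader)?;
    /// // step and epoch checkpoints land in `config.checkpoint_dir`,
    /// // and the final weights in `final_model.safetensors`
    /// ```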
    pub fn train(&mut self, dataloader: &mut impl BatchDataLoader) -> Result<()> {
        log::info!("Starting training for {} epochs", self.config.num_epochs);
        log::info!("Total batches per epoch: {}", dataloader.num_batches());

        for epoch in 0..self.config.num_epochs {
            log::info!("=== Epoch {}/{} ===", epoch + 1, self.config.num_epochs);

            let avg_loss = self.train_epoch(dataloader)?;

            log::info!(
                "Epoch {} complete: avg_loss={:.4}, step={}",
                epoch + 1,
                avg_loss,
                self.step
            );

            // Save epoch checkpoint
            let checkpoint_path = format!(
                "{}/checkpoint_epoch_{}.safetensors",
                self.config.checkpoint_dir,
                epoch + 1
            );
            self.save_checkpoint(&checkpoint_path, Some(avg_loss as f64))?;
        }

        log::info!("Training complete!");

        // Save final model
        let final_path = format!("{}/final_model.safetensors", self.config.checkpoint_dir);
        log::info!("Saving final model to {}", final_path);
        self.varmap.save(&final_path)?;

        Ok(())
    }
}