Skip to main content

trustformers_optim/
optimizer.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use trustformers_core::tensor::Tensor;
4
5/// Trait for optimizer state management and parameter updates.
6pub trait OptimizerState {
7    /// Zero out gradients
8    fn zero_grad(&mut self) -> Result<()>;
9
10    /// Perform optimization step
11    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()>;
12
13    /// Get current learning rate
14    fn get_lr(&self) -> f32;
15
16    /// Set learning rate
17    fn set_lr(&mut self, lr: f32);
18
19    /// Save optimizer state to dictionary
20    fn state_dict(&self) -> Result<HashMap<String, Tensor>>;
21
22    /// Load optimizer state from dictionary
23    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()>;
24}