trustformers_optim/optimizer.rs
1use anyhow::Result;
2use std::collections::HashMap;
3use trustformers_core::tensor::Tensor;
4
5/// Trait for optimizer state management and parameter updates.
6pub trait OptimizerState {
7 /// Zero out gradients
8 fn zero_grad(&mut self) -> Result<()>;
9
10 /// Perform optimization step
11 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()>;
12
13 /// Get current learning rate
14 fn get_lr(&self) -> f32;
15
16 /// Set learning rate
17 fn set_lr(&mut self, lr: f32);
18
19 /// Save optimizer state to dictionary
20 fn state_dict(&self) -> Result<HashMap<String, Tensor>>;
21
22 /// Load optimizer state from dictionary
23 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()>;
24}