Trait ParameterOptimizer

Source
/// Generic parameter optimization interface.
///
/// Implementors keep per-parameter optimizer state and learning rates for a
/// set of parameters identified by [`ParamId`]. `A` is the floating-point
/// element type of the parameter arrays and `D` is their dimension type.
pub trait ParameterOptimizer<A, D>
where
    A: Float,
    D: Dimension,
{
    /// Register a parameter for optimization.
    ///
    /// `parameter` is the current value of the parameter and `metadata`
    /// carries any extra information the optimizer needs about it.
    /// Failure is reported through the returned `Result`.
    fn register_parameter(
        &mut self,
        param_id: ParamId,
        parameter: &Array<A, D>,
        metadata: ParameterMetadata,
    ) -> Result<()>;

    /// Update registered parameters with gradients.
    ///
    /// `gradients` is consumed. `parameters` is mutated in place, keyed by
    /// the same `ParamId`s. NOTE(review): the trait does not say how IDs
    /// missing from either map, or unregistered IDs, are handled. Confirm
    /// against the implementors.
    fn step(
        &mut self,
        gradients: HashMap<ParamId, Array<A, D>>,
        parameters: &mut HashMap<ParamId, Array<A, D>>,
    ) -> Result<()>;

    /// Get the learning rate used for a specific parameter, if available.
    fn get_learning_rate(&self, param_id: &ParamId) -> Option<A>;

    /// Set the learning rate used for a specific parameter.
    ///
    /// Failure is reported through the returned `Result`. Presumably this
    /// covers unknown IDs; confirm against the implementors.
    fn set_learning_rate(&mut self, param_id: &ParamId, lr: A) -> Result<()>;

    /// Get the optimizer state kept for a parameter.
    ///
    /// The state is a map from state names to arrays. NOTE(review):
    /// `None` presumably means the parameter is unknown or has no state yet.
    fn get_parameter_state(&self, param_id: &ParamId) -> Option<&HashMap<String, Array<A, D>>>;

    /// Reset optimizer state.
    fn reset_state(&mut self);

    /// Get the IDs of all registered parameters.
    ///
    /// The trait does not specify the order of the returned IDs.
    fn registered_parameters(&self) -> Vec<ParamId>;
}
Expand description

Generic parameter optimization interface

Required Methods

Source

fn register_parameter(&mut self, param_id: ParamId, parameter: &Array<A, D>, metadata: ParameterMetadata) -> Result<()>

Register a parameter for optimization

Source

fn step(&mut self, gradients: HashMap<ParamId, Array<A, D>>, parameters: &mut HashMap<ParamId, Array<A, D>>) -> Result<()>

Update the registered parameters using the supplied gradients

Source

fn get_learning_rate(&self, param_id: &ParamId) -> Option<A>

Get the learning rate for a specific parameter, if available

Source

fn set_learning_rate(&mut self, param_id: &ParamId, lr: A) -> Result<()>

Set parameter-specific learning rate

Source

fn get_parameter_state(&self, param_id: &ParamId) -> Option<&HashMap<String, Array<A, D>>>

Get optimizer state for a parameter

Source

fn reset_state(&mut self)

Reset optimizer state

Source

fn registered_parameters(&self) -> Vec<ParamId>

Get all registered parameter IDs

Implementors