/// Generic parameter optimization interface.
///
/// An optimizer tracks a set of parameters, each identified by a `ParamId`
/// (see [`ParameterOptimizer::register_parameter`] and
/// [`ParameterOptimizer::registered_parameters`]), and updates them from
/// gradients via [`ParameterOptimizer::step`].
pub trait ParameterOptimizer<A: Float, D: Dimension> {
    // Required methods

    /// Register a parameter for optimization.
    ///
    /// `parameter` is borrowed, not taken; `metadata` is supplied alongside it
    /// at registration time.
    ///
    /// # Errors
    ///
    /// Fallible; the failure conditions are implementation-defined.
    fn register_parameter(
        &mut self,
        param_id: ParamId,
        parameter: &Array<A, D>,
        metadata: ParameterMetadata,
    ) -> Result<()>;

    /// Update registered parameters with gradients.
    ///
    /// `gradients` maps parameter IDs to their gradients and is taken by value.
    /// `parameters` maps parameter IDs to their current values and is mutated
    /// in place.
    ///
    /// NOTE(review): behaviour for IDs present in one map but not the other, or
    /// not registered at all, is not specified by the trait — confirm per
    /// implementation.
    ///
    /// # Errors
    ///
    /// Fallible; the failure conditions are implementation-defined.
    fn step(
        &mut self,
        gradients: HashMap<ParamId, Array<A, D>>,
        parameters: &mut HashMap<ParamId, Array<A, D>>,
    ) -> Result<()>;

    /// Get the parameter-specific learning rate for `param_id`.
    ///
    /// Returns `None` when no learning rate is available for that ID
    /// (NOTE(review): presumably when the parameter is not registered — confirm).
    fn get_learning_rate(&self, param_id: &ParamId) -> Option<A>;

    /// Set the parameter-specific learning rate for `param_id` to `lr`.
    ///
    /// # Errors
    ///
    /// Fallible; the failure conditions are implementation-defined.
    fn set_learning_rate(&mut self, param_id: &ParamId, lr: A) -> Result<()>;

    /// Get the optimizer state for a parameter.
    ///
    /// The state is a map from a string name to an array of the same element
    /// type and dimensionality as the parameter. Returns `None` when no state
    /// is available for `param_id`.
    fn get_parameter_state(
        &self,
        param_id: &ParamId,
    ) -> Option<&HashMap<String, Array<A, D>>>;

    /// Reset optimizer state.
    ///
    /// NOTE(review): whether parameter registrations survive a reset is not
    /// specified by the trait — confirm per implementation.
    fn reset_state(&mut self);

    /// Get all registered parameter IDs.
    ///
    /// NOTE(review): the order of the returned IDs is not specified by the
    /// trait — do not rely on it without checking the implementation.
    fn registered_parameters(&self) -> Vec<ParamId>;
}
Generic parameter optimization interface.
Required Methods
fn register_parameter(
    &mut self,
    param_id: ParamId,
    parameter: &Array<A, D>,
    metadata: ParameterMetadata,
) -> Result<()>
Register a parameter for optimization.
fn step(
    &mut self,
    gradients: HashMap<ParamId, Array<A, D>>,
    parameters: &mut HashMap<ParamId, Array<A, D>>,
) -> Result<()>
Update the registered parameters with their gradients.
fn get_learning_rate(&self, param_id: &ParamId) -> Option<A>
Get the parameter-specific learning rate.
fn set_learning_rate(&mut self, param_id: &ParamId, lr: A) -> Result<()>
Set the parameter-specific learning rate.
fn get_parameter_state(
    &self,
    param_id: &ParamId,
) -> Option<&HashMap<String, Array<A, D>>>
Get the optimizer state for a parameter.
fn reset_state(&mut self)
Reset the optimizer state.
fn registered_parameters(&self) -> Vec<ParamId>
Get all registered parameter IDs.