Skip to main content

TensorFlowOptimizer

Trait TensorFlowOptimizer 

Source
/// TensorFlow-compatible optimizer interface.
///
/// Implementors must be both `Sync` and `Send`. Methods are grouped below by
/// role: training steps, introspection, hyperparameters, and optimizer state.
pub trait TensorFlowOptimizer: Sync + Send {
    // ---- Training steps ---------------------------------------------------

    /// Applies gradients to variables.
    ///
    /// Each entry of `grads_and_vars` pairs a gradient `Tensor` with a
    /// `String`, presumably the name of the variable it updates (TODO confirm
    /// against implementors). `global_step` is optional; its exact use is left
    /// to the implementation.
    fn apply_gradients(
        &mut self,
        grads_and_vars: &[(Tensor, String)],
        global_step: Option<i64>,
    ) -> Result<()>;

    /// Minimizes a loss function.
    ///
    /// `loss_fn` is a boxed closure that yields a `Tensor` or an error. As a
    /// `Box<dyn Fn ...>` with no explicit lifetime it must be `'static`, and
    /// it is not required to be `Send`. `var_list` presumably names the
    /// variables to optimize, and the returned `Tensor` is presumably the
    /// computed loss. NOTE(review): confirm both against implementors.
    fn minimize(
        &mut self,
        loss_fn: Box<dyn Fn() -> Result<Tensor>>,
        var_list: &[String],
        global_step: Option<i64>,
    ) -> Result<Tensor>;

    // ---- Introspection ----------------------------------------------------

    /// Returns the optimizer's name.
    fn get_name(&self) -> &str;

    /// Returns the optimizer's configuration.
    fn get_config(&self) -> TensorFlowOptimizerConfig;

    // ---- Hyperparameters --------------------------------------------------

    /// Returns the current learning rate.
    fn get_learning_rate(&self) -> f64;

    /// Sets the learning rate; implementations may reject a value by
    /// returning an error.
    fn set_learning_rate(&mut self, lr: f64) -> Result<()>;

    // ---- Optimizer state --------------------------------------------------

    /// Returns the optimizer's variables (its state), as `String`s.
    fn variables(&self) -> Vec<String>;

    /// Returns the optimizer's weights.
    fn get_weights(&self) -> Vec<Tensor>;

    /// Replaces the optimizer's weights; may fail with an error.
    fn set_weights(&mut self, weights: Vec<Tensor>) -> Result<()>;
}
Expand description

A TensorFlow-compatible optimizer interface.

Required Methods

Source

fn apply_gradients(&mut self, grads_and_vars: &[(Tensor, String)], global_step: Option<i64>) -> Result<()>

Apply gradients to variables

Source

fn minimize(&mut self, loss_fn: Box<dyn Fn() -> Result<Tensor>>, var_list: &[String], global_step: Option<i64>) -> Result<Tensor>

Minimize loss function

Source

fn get_config(&self) -> TensorFlowOptimizerConfig

Get optimizer configuration

Source

fn variables(&self) -> Vec<String>

Get optimizer variables (state)

Source

fn get_weights(&self) -> Vec<Tensor>

Get optimizer weights

Source

fn set_weights(&mut self, weights: Vec<Tensor>) -> Result<()>

Set optimizer weights

Source

fn get_learning_rate(&self) -> f64

Get learning rate

Source

fn set_learning_rate(&mut self, lr: f64) -> Result<()>

Set learning rate

Source

fn get_name(&self) -> &str

Get optimizer name

Implementors