pub trait TensorFlowOptimizer: Send + Sync {
// Required methods
fn apply_gradients(
&mut self,
grads_and_vars: &[(Tensor, String)],
global_step: Option<i64>,
) -> Result<()>;
fn minimize(
&mut self,
loss_fn: Box<dyn Fn() -> Result<Tensor>>,
var_list: &[String],
global_step: Option<i64>,
) -> Result<Tensor>;
fn get_config(&self) -> TensorFlowOptimizerConfig;
fn variables(&self) -> Vec<String>;
fn get_weights(&self) -> Vec<Tensor>;
fn set_weights(&mut self, weights: Vec<Tensor>) -> Result<()>;
fn get_learning_rate(&self) -> f64;
fn set_learning_rate(&mut self, lr: f64) -> Result<()>;
fn get_name(&self) -> &str;
}Expand description
TensorFlow-compatible optimizer interface.
Required Methods
fn apply_gradients(
&mut self,
grads_and_vars: &[(Tensor, String)],
global_step: Option<i64>,
) -> Result<()>
fn apply_gradients(&mut self, grads_and_vars: &[(Tensor, String)], global_step: Option<i64>) -> Result<()>
Apply each gradient in `grads_and_vars` to the variable it is paired with.
fn minimize(
&mut self,
loss_fn: Box<dyn Fn() -> Result<Tensor>>,
var_list: &[String],
global_step: Option<i64>,
) -> Result<Tensor>
fn minimize(&mut self, loss_fn: Box<dyn Fn() -> Result<Tensor>>, var_list: &[String], global_step: Option<i64>) -> Result<Tensor>
Minimize the loss produced by `loss_fn` with respect to the variables listed in `var_list`.
fn get_config(&self) -> TensorFlowOptimizerConfig
fn get_config(&self) -> TensorFlowOptimizerConfig
Get the optimizer configuration.
fn get_weights(&self) -> Vec<Tensor>
fn get_weights(&self) -> Vec<Tensor>
Get the optimizer weights.
fn get_learning_rate(&self) -> f64
fn get_learning_rate(&self) -> f64
Get the current learning rate.
fn set_learning_rate(&mut self, lr: f64) -> Result<()>
fn set_learning_rate(&mut self, lr: f64) -> Result<()>
Set the learning rate to `lr`.