use super::var_store::{VarStore, Variables};
use crate::wrappers::optimizer::COptimizer;
use crate::{TchError, Tensor};
use std::sync::{Arc, Mutex};
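/// An optimizer to run gradient descent.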
#[derive(Debug)]
pub struct Optimizer<T> {
opt: COptimizer,
variables: Arc<Mutex<Variables>>,
variables_in_optimizer: usize,
config: T,
}
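/// Optimizer configurations. These configs can be used to build an optimizer
/// over the trainable variables of a `VarStore`.
///
/// A minimal usage sketch, assuming a `VarStore` named `vs` and a scalar `loss`
/// tensor computed from its trainable variables:
/// ```ignore
/// let mut opt = nn::Sgd::default().build(&vs, 1e-2)?;
/// opt.backward_step(&loss);
/// ```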
pub trait OptimizerConfig
where
Self: std::marker::Sized,
{
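/// Builds the underlying C++ optimizer with the given learning rate.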
fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError>;
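/// Builds an optimizer with the specified learning rate, handling all the
/// trainable variables currently stored in `vs`.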
fn build(self, vs: &VarStore, lr: f64) -> Result<Optimizer<Self>, TchError> {
let mut opt = self.build_copt(lr)?;
let v = vs.variables_.lock().unwrap();
for var in &v.trainable_variables {
opt.add_parameters(&var.tensor, var.group)?;
}
Ok(Optimizer {
opt,
variables: vs.variables_.clone(),
variables_in_optimizer: v.trainable_variables.len(),
config: self,
})
}
}
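/// Parameters for the SGD optimizer.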
#[derive(Debug, Copy, Clone)]
pub struct Sgd {
pub momentum: f64,
pub dampening: f64,
pub wd: f64,
pub nesterov: bool,
}
impl Default for Sgd {
fn default() -> Self {
Sgd {
momentum: 0.,
dampening: 0.,
wd: 0.,
nesterov: false,
}
}
}
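/// Creates the configuration for a Stochastic Gradient Descent (SGD) optimizer.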
pub fn sgd(momentum: f64, dampening: f64, wd: f64, nesterov: bool) -> Sgd {
Sgd {
momentum,
dampening,
wd,
nesterov,
}
}
impl OptimizerConfig for Sgd {
fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
COptimizer::sgd(lr, self.momentum, self.dampening, self.wd, self.nesterov)
}
}
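/// Parameters for the Adam optimizer.
///
/// A sketch of overriding a single field while keeping the other defaults,
/// assuming a `VarStore` named `vs`:
/// ```ignore
/// let adam = Adam { wd: 5e-4, ..Default::default() };
/// let mut opt = adam.build(&vs, 1e-3)?;
/// ```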
#[derive(Debug, Copy, Clone)]
pub struct Adam {
pub beta1: f64,
pub beta2: f64,
pub wd: f64,
}
impl Default for Adam {
fn default() -> Self {
Adam {
beta1: 0.9,
beta2: 0.999,
wd: 0.,
}
}
}
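/// Creates the configuration for the Adam optimizer.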
pub fn adam(beta1: f64, beta2: f64, wd: f64) -> Adam {
Adam { beta1, beta2, wd }
}
impl OptimizerConfig for Adam {
fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
COptimizer::adam(lr, self.beta1, self.beta2, self.wd)
}
}
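/// Parameters for the AdamW optimizer.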
#[derive(Debug, Copy, Clone)]
pub struct AdamW {
pub beta1: f64,
pub beta2: f64,
pub wd: f64,
}
impl Default for AdamW {
fn default() -> Self {
AdamW {
beta1: 0.9,
beta2: 0.999,
wd: 0.01,
}
}
}
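/// Creates the configuration for the AdamW optimizer.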
pub fn adamw(beta1: f64, beta2: f64, wd: f64) -> AdamW {
AdamW { beta1, beta2, wd }
}
impl OptimizerConfig for AdamW {
fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
COptimizer::adamw(lr, self.beta1, self.beta2, self.wd)
}
}
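/// Parameters for the RmsProp optimizer.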
#[derive(Debug, Copy, Clone)]
pub struct RmsProp {
pub alpha: f64,
pub eps: f64,
pub wd: f64,
pub momentum: f64,
pub centered: bool,
}
impl Default for RmsProp {
fn default() -> Self {
RmsProp {
alpha: 0.99,
eps: 1e-8,
wd: 0.,
momentum: 0.,
centered: false,
}
}
}
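/// Creates the configuration for the RmsProp optimizer.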
pub fn rms_prop(alpha: f64, eps: f64, wd: f64, momentum: f64, centered: bool) -> RmsProp {
RmsProp {
alpha,
eps,
wd,
momentum,
centered,
}
}
impl OptimizerConfig for RmsProp {
fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
COptimizer::rms_prop(
lr,
self.alpha,
self.eps,
self.wd,
self.momentum,
self.centered,
)
}
}
impl<T> Optimizer<T> {
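// Adds to the optimizer any trainable variable registered in the var store
// after this optimizer was created.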
fn add_missing_variables(&mut self) {
let v = self.variables.lock().unwrap();
if v.trainable_variables.len() > self.variables_in_optimizer {
for var in &v.trainable_variables[self.variables_in_optimizer..] {
self.opt.add_parameters(&var.tensor, var.group).unwrap();
}
self.variables_in_optimizer = v.trainable_variables.len();
}
}
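/// Zeroes the gradient for the tensors tracked by this optimizer.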
pub fn zero_grad(&mut self) {
self.add_missing_variables();
self.opt.zero_grad().unwrap()
}
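/// Clips gradient value at some specified maximum value.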
pub fn clip_grad_value(&self, max: f64) {
let v = self.variables.lock().unwrap();
for var in v.trainable_variables.iter() {
let _t = var.tensor.grad().clamp_(-max, max);
}
}
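/// Clips gradient L2 norm over all trainable parameters.
///
/// The total norm is the L2 norm of the per-parameter gradient norms; if it
/// exceeds `max`, every gradient is scaled by `max / (total_norm + 1e-6)`.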
pub fn clip_grad_norm(&self, max: f64) {
crate::no_grad(|| {
let v = self.variables.lock().unwrap();
let mut norms = vec![];
for var in v.trainable_variables.iter() {
norms.push(var.tensor.grad().norm());
}
let total_norm = f64::from(Tensor::stack(&norms, 0).norm());
let clip_coef = max / (total_norm + 1e-6);
if clip_coef < 1.0 {
for var in v.trainable_variables.iter() {
// The in-place scalar multiply is required so that the stored gradient is
// actually rescaled; the out-of-place variant would leave it untouched.
let _t = var.tensor.grad().g_mul_1_(clip_coef);
}
}
})
}
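/// Performs an optimization step, updating the tracked tensors based on their gradients.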
pub fn step(&mut self) {
self.add_missing_variables();
self.opt.step().unwrap()
}
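/// Zeroes the gradients, applies a backward pass on `loss`, and performs an optimization step.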
pub fn backward_step(&mut self, loss: &Tensor) {
self.add_missing_variables();
self.opt.zero_grad().unwrap();
loss.backward();
self.opt.step().unwrap()
}
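/// Zeroes the gradients, applies a backward pass on `loss`, clips the gradient values
/// at `max`, and performs an optimization step.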
pub fn backward_step_clip(&mut self, loss: &Tensor, max: f64) {
self.add_missing_variables();
self.opt.zero_grad().unwrap();
loss.backward();
self.clip_grad_value(max);
self.opt.step().unwrap()
}
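/// Zeroes the gradients, applies a backward pass on `loss`, clips the gradient L2 norm
/// at `max`, and performs an optimization step.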
pub fn backward_step_clip_norm(&mut self, loss: &Tensor, max: f64) {
self.add_missing_variables();
self.opt.zero_grad().unwrap();
loss.backward();
self.clip_grad_norm(max);
self.opt.step().unwrap()
}
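/// Sets the optimizer learning rate.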
pub fn set_lr(&mut self, lr: f64) {
self.opt.set_learning_rate(lr).unwrap()
}
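/// Sets the optimizer momentum.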
pub fn set_momentum(&mut self, m: f64) {
self.opt.set_momentum(m).unwrap()
}
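/// Sets the learning rate for the parameter group identified by `group`.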
pub fn set_lr_group(&mut self, group: usize, lr: f64) {
self.opt.set_learning_rate_group(group, lr).unwrap()
}
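/// Sets the momentum for the parameter group identified by `group`.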
pub fn set_momentum_group(&mut self, group: usize, m: f64) {
self.opt.set_momentum_group(group, m).unwrap()
}
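/// Returns all the trainable variables handled by this optimizer.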
pub fn trainable_variables(&self) -> Vec<Tensor> {
let variables = self.variables.lock().unwrap();
variables
.trainable_variables
.iter()
.map(|v| v.tensor.shallow_clone())
.collect()
}
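/// Sets the optimizer weight decay.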
pub fn set_weight_decay(&mut self, weight_decay: f64) {
self.opt.set_weight_decay(weight_decay).unwrap()
}
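/// Sets the weight decay for the parameter group identified by `group`.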
pub fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64) {
self.opt
.set_weight_decay_group(group, weight_decay)
.unwrap()
}
}