1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
// Concrete optimizer implementations live in private submodules; their public
// items are re-exported below so callers can reach them through this module.
mod sgd;
mod adam;
mod rmsprop;

pub use self::sgd::*;
pub use self::adam::*;
pub use self::rmsprop::*;

use crate::backend::{Backend, BackendAxpys};
use crate::optimizer::Optimizer;
use core::marker::PhantomData;

/// Optimizer adapter that adjusts the gradients using the parameters and a
/// fixed coefficient before delegating the actual update to a wrapped
/// optimizer.
///
/// Type parameters:
/// * `N` - numeric type the backend is parameterized over.
/// * `B` - backend used to run the tensor operations.
/// * `O` - the wrapped optimizer that receives the adjusted gradients.
pub struct WeightDecay<N, B, O>
where
    B: Backend<N>,
    O: Optimizer<N, B>,
{
    /// Coefficient handed to `axpys` on every update.
    lamda: f32,
    /// Inner optimizer that performs the real parameter update.
    optimizer: O,
    /// Zero-sized marker tying `N`, `B` and `O` to the struct; no value of
    /// these types is stored through it.
    _m: PhantomData<fn(N, B, O)>,
}

impl<N, B, O> WeightDecay<N, B, O>
where
    B: Backend<N>,
    O: Optimizer<N, B>,
{
    /// Wraps `optimizer`, storing `lamda` as the coefficient applied on each
    /// call to `update_params`.
    pub fn new(lamda: f32, optimizer: O) -> Self {
        Self {
            lamda,
            optimizer,
            _m: PhantomData,
        }
    }
}

impl<N, B, O> Optimizer<N, B> for WeightDecay<N, B, O>
where
    B: Backend<N> + BackendAxpys<N>,
    O: Optimizer<N, B>,
{
    /// The wrapper keeps no per-parameter state of its own, so it reuses the
    /// wrapped optimizer's context type unchanged.
    type Context = O::Context;

    /// Runs `axpys` over `grads` with the stored coefficient and `params`,
    /// then forwards everything to the wrapped optimizer.
    ///
    /// NOTE(review): `axpys` is defined elsewhere; presumably it computes
    /// `grads += lamda * params` (the usual weight-decay term) - confirm
    /// against the `BackendAxpys` definition.
    #[inline]
    fn update_params(
        &self,
        backend: &B,
        ctx: &mut Self::Context,
        params: &mut B::Tensor,
        grads: &mut B::Tensor,
    ) {
        // Convert the stored f32 coefficient through the backend first, so it
        // is in whatever form `axpys` expects.
        let coefficient = backend.scalar_f32(self.lamda);
        backend.axpys(grads, coefficient, params);

        // The inner optimizer sees the gradients as modified above.
        self.optimizer.update_params(backend, ctx, params, grads);
    }
}