1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
mod sgd;
mod adam;
mod rmsprop;

pub use self::sgd::*;
pub use self::adam::*;
pub use self::rmsprop::*;

use crate::backend::{Backend, BackendAxpys};
use crate::optimizer::Optimizer;
use core::marker::PhantomData;

/// Optimizer wrapper that pairs a decay coefficient with an inner optimizer.
///
/// Type parameters:
/// * `N` - the type parameter passed through to `Backend<N>` and `Optimizer<N, B>`.
/// * `B` - the backend; must implement `Backend<N>`.
/// * `O` - the wrapped optimizer; must implement `Optimizer<N, B>`.
pub struct WeightDecay<N, B, O>
where
    B: Backend<N>,
    O: Optimizer<N, B>,
{
    /// Decay coefficient supplied at construction time.
    lamda: f32,
    /// The wrapped optimizer.
    optimizer: O,
    /// Zero-sized marker so the type parameters count as used; no values of
    /// `N` or `B` are stored in the struct.
    _m: PhantomData<fn(N, B, O)>,
}

impl<N, B, O> WeightDecay<N, B, O>
    where B: Backend<N>,
          O: Optimizer<N, B> 
{
    pub fn new(lamda: f32, optimizer: O) -> Self {
        Self {
            lamda,
            optimizer,
            _m: Default::default(),
        }
    }
}

impl<N, B, O> Optimizer<N, B> for WeightDecay<N, B, O>
where
    B: Backend<N> + BackendAxpys<N>,
    O: Optimizer<N, B>,
{
    /// Same context type as the wrapped optimizer; it is passed through to
    /// the inner optimizer untouched.
    type Context = O::Context;

    /// Applies the decay term to `grads`, then delegates to the wrapped optimizer.
    ///
    /// Steps:
    /// 1. Converts the stored coefficient to a backend scalar via `scalar_f32`.
    /// 2. Calls `axpys(grads, scalar, params)` on the backend.
    ///    NOTE(review): presumably this performs `grads += lamda * params`
    ///    in place; confirm against the `BackendAxpys` definition.
    /// 3. Forwards `backend`, `ctx`, `params` and `grads` to the inner
    ///    optimizer's `update_params`.
    #[inline]
    fn update_params(
        &self,
        backend: &B,
        ctx: &mut Self::Context,
        params: &mut B::Tensor,
        grads: &mut B::Tensor,
    ) {
        let decay = backend.scalar_f32(self.lamda);
        backend.axpys(grads, decay, params);

        // The inner optimizer sees the gradients after the decay step above.
        self.optimizer.update_params(backend, ctx, params, grads);
    }
}