1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//! Optimization module.
//!
//! Contains a number of optimizers.
use std::cell::Cell;
use {ParameterNode, Variable};
mod adagrad;
mod adam;
mod barrier;
mod sgd;

pub use self::adagrad::Adagrad;
pub use self::adam::Adam;
use self::barrier::SynchronizationBarrier;
pub use self::sgd::SGD;

/// Core trait implemented by all optimizer methods.
///
/// Every optimizer exported from this module, as well as the
/// `SynchronizedOptimizer` wrapper and the `Optimizers` enum, is driven
/// through this single entry point.
pub trait Optimizer {
    /// Perform a single optimization step on `parameters`.
    ///
    /// How the parameters are updated is defined by the implementing
    /// optimizer. The method takes `&self`, so implementors that keep
    /// per-step state need interior mutability.
    fn step(&self, parameters: &[Variable<ParameterNode>]);
}

/// Trait implemented by synchronizable optimizers.
///
/// Using a set of synchronized optimizers guarantees that parameter
/// updates will always happen in the same order, making results
/// reproducible at the price of some performance relative to asynchronous
/// parallel optimization.
pub trait Synchronizable {
    /// Synchronize this optimizer, producing a set of synchronized optimizers
    /// to be used by individual fitting threads.
    ///
    /// This is shorthand for `synchronized_with_step` with a `step_size`
    /// of 8.
    fn synchronized(&self, num_threads: usize) -> Vec<SynchronizedOptimizer<Self>>
    where
        Self: Sized,
    {
        // Step interval used when the caller does not pick one explicitly.
        const DEFAULT_STEP_SIZE: usize = 8;

        Self::synchronized_with_step(self, num_threads, DEFAULT_STEP_SIZE)
    }

    /// Synchronize this optimizer, producing a set of synchronized optimizers
    /// to be used by individual fitting threads. The threads will synchronize
    /// their updates every `step_size` steps.
    fn synchronized_with_step(
        &self,
        num_threads: usize,
        step_size: usize,
    ) -> Vec<SynchronizedOptimizer<Self>>
    where
        Self: Sized;
}

/// Synchronized optimizer wrapper.
///
/// Borrows an underlying optimizer and forwards to it only on every
/// `step_size`-th call to `step`, synchronizing on the barrier guard
/// before doing so.
pub struct SynchronizedOptimizer<'a, T: 'a> {
    // Number of `step` calls between two forwarded updates.
    step_size: usize,
    // Calls seen since the last forwarded update; a `Cell` because
    // `Optimizer::step` only receives `&self`.
    num_updates: Cell<usize>,
    // The wrapped optimizer that performs the actual update.
    optimizer: &'a T,
    // This wrapper's handle on the shared barrier, obtained from
    // `SynchronizationBarrier::register_thread`.
    barrier_guard: barrier::SynchronizationBarrierGuard,
}

impl<'a, T: 'a> SynchronizedOptimizer<'a, T> {
    /// Wrap `optimizer` so that it is stepped once every `step_size` calls,
    /// using `barrier_guard` to synchronize with the other wrappers.
    ///
    /// The call counter starts at zero.
    fn new(
        optimizer: &'a T,
        barrier_guard: barrier::SynchronizationBarrierGuard,
        step_size: usize,
    ) -> Self {
        let num_updates = Cell::new(0);

        Self {
            step_size,
            num_updates,
            optimizer,
            barrier_guard,
        }
    }
}

impl<'a, T> Optimizer for SynchronizedOptimizer<'a, T>
where
    T: Optimizer,
{
    /// Count this call and, on every `step_size`-th call, synchronize on the
    /// barrier guard and forward the step to the wrapped optimizer.
    ///
    /// Calls that do not reach the interval only bump the counter; the
    /// wrapped optimizer is not touched. After a forwarded step the counter
    /// is reset to zero.
    ///
    /// A `step_size` of zero is treated like a `step_size` of one: every
    /// call is forwarded. (With a strict equality check the counter would
    /// already be 1 at the first comparison and the wrapped optimizer would
    /// never run.)
    fn step(&self, parameters: &[Variable<ParameterNode>]) {
        let pending = self.num_updates.get() + 1;
        self.num_updates.set(pending);

        // `>=` rather than `==`: identical for any `step_size >= 1` because
        // the counter grows by one and is reset on every forwarded step, but
        // it also fires when `step_size` is 0.
        if pending >= self.step_size {
            // Keep whatever `synchronize` returns alive until the wrapped
            // step has finished (presumably a guard that holds this
            // thread's turn -- confirm against the barrier module).
            let _barrier = self.barrier_guard.synchronize();
            self.optimizer.step(parameters);

            self.num_updates.set(0);
        }
    }
}

impl<T> Synchronizable for T
where
    T: Optimizer + Sized,
{
    /// Build `num_threads` wrappers around `self`, each registered with one
    /// freshly created barrier and configured with the same `step_size`.
    fn synchronized_with_step(
        &self,
        num_threads: usize,
        step_size: usize,
    ) -> Vec<SynchronizedOptimizer<T>> {
        // One barrier for the whole set; every wrapper gets its own guard.
        let shared_barrier = SynchronizationBarrier::default();
        let mut per_thread = Vec::with_capacity(num_threads);

        for _ in 0..num_threads {
            let guard = shared_barrier.register_thread();
            per_thread.push(SynchronizedOptimizer::new(self, guard, step_size));
        }

        per_thread
    }
}

// Generates the `Optimizers` enum plus its `Optimizer` impl from a
// comma-separated list of `(VariantName, WrappedType)` pairs. Each variant
// holds one optimizer and `step` is forwarded to whichever one is stored.
macro_rules! impl_optimizer_enum {
    ($(($variant:ident, $inner:ty)),*) => {
        /// Enum containing all optimizers.
        ///
        /// Makes runtime switching between optimizers slightly more ergonomic.
        pub enum Optimizers {
            $(
                #[allow(missing_docs)]
                $variant($inner),
            )*
        }

        impl Optimizer for Optimizers {
            fn step(&self, parameters: &[Variable<ParameterNode>]) {
                match *self {
                    $(
                        Optimizers::$variant(ref optimizer) => optimizer.step(parameters),
                    )*
                }
            }
        }
    }
}

// Instantiate `Optimizers` with one variant per optimizer re-exported from
// this module; each variant is named after the type it wraps.
impl_optimizer_enum!((SGD, SGD), (Adagrad, Adagrad), (Adam, Adam));