yscv_optim/adagrad.rs

use std::collections::HashMap;
use std::collections::hash_map::Entry;

use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::validate::{validate_epsilon, validate_lr};
use super::{LearningRate, OptimError};

#[derive(Debug, Clone)]
struct AdagradState {
    sum_sq: Tensor,
}

impl AdagradState {
    fn new(shape: &[usize]) -> Result<Self, OptimError> {
        Ok(Self {
            sum_sq: Tensor::zeros(shape.to_vec())?,
        })
    }

    fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
        *self = Self::new(shape)?;
        Ok(())
    }
}

/// Adagrad optimizer with optional L2 weight decay.
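///
/// Per-element update applied by [`Adagrad::step`], with `g` the (optionally
/// decayed) gradient: `g = grad + weight_decay * w`, then `sum_sq += g * g`,
/// then `w -= lr * g / (sqrt(sum_sq) + epsilon)`.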
#[derive(Debug, Clone)]
pub struct Adagrad {
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
    state: HashMap<u64, AdagradState>,
}

impl Adagrad {
    /// Creates an Adagrad optimizer with the required learning rate.
    pub fn new(lr: f32) -> Result<Self, OptimError> {
        validate_lr(lr)?;
        Ok(Self {
            lr,
            epsilon: 1e-10,
            weight_decay: 0.0,
            state: HashMap::new(),
        })
    }

    /// Sets the epsilon value; it must be finite and `> 0`.
    pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
        validate_epsilon(epsilon)?;
        self.epsilon = epsilon;
        Ok(self)
    }

    /// Sets the L2 weight decay factor; it must be finite and non-negative.
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
        if !weight_decay.is_finite() || weight_decay < 0.0 {
            return Err(OptimError::InvalidWeightDecay { weight_decay });
        }
        self.weight_decay = weight_decay;
        Ok(self)
    }

    /// Drops all per-parameter optimizer state (for example, when restarting training).
    pub fn clear_state(&mut self) {
        self.state.clear();
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.lr
    }

    /// Overrides the current learning rate.
    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        validate_lr(lr)?;
        self.lr = lr;
        Ok(())
    }

    /// Applies one update to the raw tensor weights identified by `parameter_id`.
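    ///
    /// # Example
    ///
    /// A minimal sketch, assuming tensor errors convert into `OptimError`
    /// (as in `AdagradState::new` above); the zero tensors are placeholders:
    ///
    /// ```ignore
    /// let mut optimizer = Adagrad::new(0.1)?.with_weight_decay(1e-4)?;
    /// let mut weights = Tensor::zeros(vec![3])?;
    /// let grad = Tensor::zeros(vec![3])?;
    /// optimizer.step(0, &mut weights, &grad)?;
    /// ```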
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        if weights.shape() != grad.shape() {
            return Err(OptimError::ShapeMismatch {
                weights: weights.shape().to_vec(),
                grad: grad.shape().to_vec(),
            });
        }

        let state = match self.state.entry(parameter_id) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => entry.insert(AdagradState::new(weights.shape())?),
        };
        if state.sum_sq.shape() != weights.shape() {
            state.reset(weights.shape())?;
        }

        let sum_sq = state.sum_sq.data_mut();
        let grad_values = grad.data();
        let weights_data = weights.data_mut();

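        // Fold L2 weight decay into the gradient, accumulate its square, and
        // scale the update by 1 / (sqrt(sum_sq) + epsilon) for each element.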
        for index in 0..weights_data.len() {
            let grad_value = grad_values[index] + self.weight_decay * weights_data[index];
            sum_sq[index] += grad_value * grad_value;
            weights_data[index] -= self.lr * grad_value / (sum_sq[index].sqrt() + self.epsilon);
        }

        Ok(())
    }

    /// Applies one update to a trainable graph node by its `NodeId`.
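    ///
    /// Nodes that do not require gradients are skipped; a node without a
    /// computed gradient yields `OptimError::MissingGradient`.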
    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
        if !graph.requires_grad(node)? {
            return Ok(());
        }

        let grad = match graph.grad(node)? {
            Some(grad) => grad.clone(),
            None => return Err(OptimError::MissingGradient { node: node.0 }),
        };
        let weights = graph.value_mut(node)?;
        self.step(node.0 as u64, weights, &grad)
    }
}

impl LearningRate for Adagrad {
    fn learning_rate(&self) -> f32 {
        Adagrad::learning_rate(self)
    }

    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        Adagrad::set_learning_rate(self, lr)
    }
}
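
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal smoke-test sketch. It relies only on `Tensor::zeros` and
    // `Tensor::data`, which are already used in this file, and assumes tensor
    // errors convert into `OptimError` as in `AdagradState::new`. With an
    // all-zero gradient and the default weight decay of 0.0, the update term
    // is zero, so the weights must stay unchanged.
    #[test]
    fn zero_gradient_leaves_weights_unchanged() -> Result<(), OptimError> {
        let mut optimizer = Adagrad::new(0.01)?;
        let mut weights = Tensor::zeros(vec![4])?;
        let grad = Tensor::zeros(vec![4])?;
        optimizer.step(0, &mut weights, &grad)?;
        assert!(weights.data().iter().all(|&value| value == 0.0));
        Ok(())
    }
}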