1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_lr, validate_momentum};
8use super::{LearningRate, OptimError};
9
/// Layer-wise Adaptive Rate Scaling (LARS) optimizer state.
///
/// Keeps global hyper-parameters plus one momentum ("velocity") buffer per
/// parameter id, created lazily on the first `step` for that parameter.
#[derive(Debug, Clone)]
pub struct Lars {
    /// Global base learning rate; multiplied by the per-parameter trust
    /// ratio in `step` to form the effective learning rate.
    base_lr: f32,
    /// Momentum coefficient applied to each velocity buffer (0.0 = plain SGD-style update).
    momentum: f32,
    /// L2 weight-decay coefficient; folded into the gradient during `step`.
    weight_decay: f32,
    /// Trust coefficient scaling the layer-wise ratio ||w|| / ||g||
    /// (presumably η from the LARS paper — the formula in `step` matches).
    trust_coefficient: f32,
    /// Per-parameter momentum buffers keyed by parameter id.
    velocity: HashMap<u64, Tensor>,
}
22
23impl Lars {
24 pub fn new(base_lr: f32) -> Result<Self, OptimError> {
26 validate_lr(base_lr)?;
27 Ok(Self {
28 base_lr,
29 momentum: 0.0,
30 weight_decay: 0.0,
31 trust_coefficient: 0.001,
32 velocity: HashMap::new(),
33 })
34 }
35
36 pub fn with_momentum(mut self, momentum: f32) -> Result<Self, OptimError> {
38 validate_momentum(momentum)?;
39 self.momentum = momentum;
40 Ok(self)
41 }
42
43 pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
45 if !weight_decay.is_finite() || weight_decay < 0.0 {
46 return Err(OptimError::InvalidWeightDecay { weight_decay });
47 }
48 self.weight_decay = weight_decay;
49 Ok(self)
50 }
51
52 pub fn with_trust_coefficient(mut self, trust_coefficient: f32) -> Result<Self, OptimError> {
54 if !trust_coefficient.is_finite() || trust_coefficient <= 0.0 {
55 return Err(OptimError::InvalidEpsilon {
56 epsilon: trust_coefficient,
57 });
58 }
59 self.trust_coefficient = trust_coefficient;
60 Ok(self)
61 }
62
63 pub fn clear_state(&mut self) {
65 self.velocity.clear();
66 }
67
68 pub fn learning_rate(&self) -> f32 {
70 self.base_lr
71 }
72
73 pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
75 validate_lr(lr)?;
76 self.base_lr = lr;
77 Ok(())
78 }
79
80 pub fn step(
82 &mut self,
83 parameter_id: u64,
84 weights: &mut Tensor,
85 grad: &Tensor,
86 ) -> Result<(), OptimError> {
87 if weights.shape() != grad.shape() {
88 return Err(OptimError::ShapeMismatch {
89 weights: weights.shape().to_vec(),
90 grad: grad.shape().to_vec(),
91 });
92 }
93
94 let w_data = weights.data();
96 let g_data = grad.data();
97
98 let w_norm = w_data.iter().map(|x| x * x).sum::<f32>().sqrt();
99 let g_norm = g_data.iter().map(|x| x * x).sum::<f32>().sqrt();
100
101 let local_lr = if w_norm > 0.0 && g_norm > 0.0 {
103 self.trust_coefficient * w_norm / (g_norm + self.weight_decay * w_norm)
104 } else {
105 1.0
106 };
107
108 let mut g_with_wd = g_data.to_vec();
110 if self.weight_decay != 0.0 {
111 for (gv, wv) in g_with_wd.iter_mut().zip(w_data.iter()) {
112 *gv += self.weight_decay * *wv;
113 }
114 }
115
116 let effective_lr = local_lr * self.base_lr;
117
118 let velocity = match self.velocity.entry(parameter_id) {
120 Entry::Occupied(entry) => entry.into_mut(),
121 Entry::Vacant(entry) => entry.insert(Tensor::zeros(weights.shape().to_vec())?),
122 };
123
124 if velocity.shape() != weights.shape() {
125 *velocity = Tensor::zeros(weights.shape().to_vec())?;
126 }
127
128 let v_data = velocity.data_mut();
129 let weights_data = weights.data_mut();
130
131 for i in 0..weights_data.len() {
132 v_data[i] = self.momentum * v_data[i] + effective_lr * g_with_wd[i];
133 weights_data[i] -= v_data[i];
134 }
135
136 Ok(())
137 }
138
139 pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
141 if !graph.requires_grad(node)? {
142 return Ok(());
143 }
144
145 let grad = match graph.grad(node)? {
146 Some(grad) => grad.clone(),
147 None => return Err(OptimError::MissingGradient { node: node.0 }),
148 };
149 let weights = graph.value_mut(node)?;
150 self.step(node.0 as u64, weights, &grad)
151 }
152}
153
154impl LearningRate for Lars {
155 fn learning_rate(&self) -> f32 {
156 Lars::learning_rate(self)
157 }
158
159 fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
160 Lars::set_learning_rate(self, lr)
161 }
162}