1use crate::{
2 adam::{Adam, AdamW},
3 scheduler::LRScheduler,
4 sgd::SGD,
5};
6use trustformers_core::{errors::Result, tensor::Tensor, traits::Optimizer};
7
/// BERT-style training optimizer: AdamW combined with a linear
/// warmup/decay learning-rate schedule, plus bookkeeping for layer-wise
/// learning-rate decay and weight-decay exclusions.
pub struct BERTOptimizer {
    // Underlying AdamW optimizer that performs the actual parameter updates.
    base_optimizer: AdamW,
    // Produces the warmup/decay learning-rate schedule.
    warmup_scheduler: Box<dyn LRScheduler>,
    // Per-layer-index multiplier used by `get_layer_wise_lr`.
    #[allow(dead_code)]
    layer_wise_decay: f32,
    // Parameter-name substrings whose tensors should skip weight decay.
    #[allow(dead_code)]
    weight_decay_exclusions: Vec<String>,
    // Number of `step()` calls taken so far.
    current_step: usize,
    #[allow(dead_code)]
    warmup_steps: usize,
    #[allow(dead_code)]
    total_steps: usize,
}
22
23impl BERTOptimizer {
24 pub fn new(
25 learning_rate: f32,
26 warmup_steps: usize,
27 total_steps: usize,
28 layer_wise_decay: f32,
29 ) -> Result<Self> {
30 let base_optimizer = AdamW::new(learning_rate, (0.9, 0.999), 1e-6, 0.01);
31
32 let warmup_scheduler = Box::new(BERTWarmupScheduler::new(
34 learning_rate,
35 warmup_steps,
36 total_steps,
37 ));
38
39 let weight_decay_exclusions = vec![
41 "bias".to_string(),
42 "LayerNorm".to_string(),
43 "layer_norm".to_string(),
44 "ln".to_string(),
45 ];
46
47 Ok(Self {
48 base_optimizer,
49 warmup_scheduler,
50 layer_wise_decay,
51 weight_decay_exclusions,
52 current_step: 0,
53 warmup_steps,
54 total_steps,
55 })
56 }
57
58 #[allow(dead_code)]
60 fn get_layer_wise_lr(&self, param_name: &str, base_lr: f32) -> f32 {
61 if let Some(layer_num) = self.extract_layer_number(param_name) {
63 let decay_factor = self.layer_wise_decay.powi(layer_num as i32);
64 base_lr * decay_factor
65 } else {
66 base_lr
67 }
68 }
69
70 fn extract_layer_number(&self, param_name: &str) -> Option<usize> {
71 if param_name.contains("layer.") {
73 let parts: Vec<&str> = param_name.split('.').collect();
74 for i in 0..parts.len() {
75 if parts[i] == "layer" && i + 1 < parts.len() {
76 if let Ok(layer_num) = parts[i + 1].parse::<usize>() {
77 return Some(layer_num);
78 }
79 }
80 }
81 }
82 None
83 }
84
85 #[allow(dead_code)]
86 fn should_exclude_weight_decay(&self, param_name: &str) -> bool {
87 self.weight_decay_exclusions
88 .iter()
89 .any(|exclusion| param_name.contains(exclusion))
90 }
91}
92
93impl Optimizer for BERTOptimizer {
94 fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
95 self.base_optimizer.update(parameter, grad)
96 }
97
98 fn zero_grad(&mut self) {
99 self.base_optimizer.zero_grad()
100 }
101
102 fn step(&mut self) {
103 self.base_optimizer.step();
104 self.warmup_scheduler.step();
105 self.current_step += 1;
106 }
107
108 fn get_lr(&self) -> f32 {
109 self.base_optimizer.get_lr()
110 }
111
112 fn set_lr(&mut self, lr: f32) {
113 self.base_optimizer.set_lr(lr)
114 }
115}
116
/// Linear-warmup / linear-decay learning-rate schedule, as used for BERT
/// pre-training and fine-tuning.
struct BERTWarmupScheduler {
    // Peak learning rate, reached at the end of warmup.
    base_lr: f32,
    // Steps spent ramping linearly from 0 up to `base_lr`.
    warmup_steps: usize,
    // Total training steps; the LR decays linearly toward 0 by this point.
    total_steps: usize,
    // Steps counted via `step()`; note `get_lr` takes the step explicitly.
    current_step: usize,
}
124
125impl BERTWarmupScheduler {
126 fn new(base_lr: f32, warmup_steps: usize, total_steps: usize) -> Self {
127 Self {
128 base_lr,
129 warmup_steps,
130 total_steps,
131 current_step: 0,
132 }
133 }
134}
135
136impl LRScheduler for BERTWarmupScheduler {
137 fn step(&mut self) {
138 self.current_step += 1;
139 }
140
141 fn get_lr(&self, step: usize) -> f32 {
142 if step < self.warmup_steps {
143 self.base_lr * (step as f32 / self.warmup_steps as f32)
145 } else {
146 let progress =
148 (step - self.warmup_steps) as f32 / (self.total_steps - self.warmup_steps) as f32;
149 self.base_lr * (1.0 - progress).max(0.0)
150 }
151 }
152}
153
/// Paired Adam optimizers for GAN training with optional spectral
/// normalization, gradient penalty, and TTUR-style update ratios.
pub struct GANOptimizer {
    generator_optimizer: Adam,
    discriminator_optimizer: Adam,
    // When true, matrix-shaped discriminator weights are rescaled by their
    // estimated spectral norm before each update.
    spectral_norm: bool,
    // Weight applied by `apply_gradient_penalty`; 0 disables the penalty.
    gradient_penalty_weight: f32,
    // Two-time-scale update rule flag; set when d_lr != g_lr.
    #[allow(dead_code)]
    ttur: bool,
    // Discriminator steps required before the generator takes a step.
    d_steps_per_g_step: usize,
    // Discriminator steps taken since the last generator step.
    current_d_steps: usize,
}
165
166impl GANOptimizer {
167 pub fn new(g_lr: f32, d_lr: f32, spectral_norm: bool, gradient_penalty_weight: f32) -> Self {
168 let generator_optimizer = Adam::new(g_lr, (0.0, 0.999), 1e-8, 0.0);
169 let discriminator_optimizer = Adam::new(d_lr, (0.0, 0.999), 1e-8, 0.0);
170
171 Self {
172 generator_optimizer,
173 discriminator_optimizer,
174 spectral_norm,
175 gradient_penalty_weight,
176 ttur: d_lr != g_lr,
177 d_steps_per_g_step: if d_lr > g_lr { 5 } else { 1 },
178 current_d_steps: 0,
179 }
180 }
181
182 pub fn step_discriminator(
183 &mut self,
184 d_params: &mut [Tensor],
185 d_grads: &[Tensor],
186 ) -> Result<()> {
187 let mut modified_grads = d_grads.to_vec();
189 if self.gradient_penalty_weight > 0.0 {
190 self.apply_gradient_penalty(&mut modified_grads)?;
191 }
192
193 if self.spectral_norm {
195 self.apply_spectral_norm(d_params)?;
196 }
197
198 for (param, grad) in d_params.iter_mut().zip(modified_grads.iter()) {
199 self.discriminator_optimizer.update(param, grad)?;
200 }
201 self.discriminator_optimizer.step();
202 self.current_d_steps += 1;
203 Ok(())
204 }
205
206 pub fn step_generator(&mut self, g_params: &mut [Tensor], g_grads: &[Tensor]) -> Result<()> {
207 if self.current_d_steps >= self.d_steps_per_g_step {
209 for (param, grad) in g_params.iter_mut().zip(g_grads.iter()) {
210 self.generator_optimizer.update(param, grad)?;
211 }
212 self.generator_optimizer.step();
213 self.current_d_steps = 0;
214 }
215 Ok(())
216 }
217
218 fn apply_gradient_penalty(&self, gradients: &mut [Tensor]) -> Result<()> {
219 for grad in gradients.iter_mut() {
221 let grad_norm = self.compute_gradient_norm(grad)?;
222 if grad_norm > 1.0 {
223 let penalty = (grad_norm - 1.0).powi(2) * self.gradient_penalty_weight;
224 *grad = grad.add_scalar(penalty)?;
225 }
226 }
227 Ok(())
228 }
229
230 fn apply_spectral_norm(&self, parameters: &mut [Tensor]) -> Result<()> {
231 for param in parameters.iter_mut() {
233 if param.shape().len() >= 2 {
234 let spectral_norm = self.compute_spectral_norm(param)?;
236 if spectral_norm > 1.0 {
237 *param = param.div_scalar(spectral_norm)?;
238 }
239 }
240 }
241 Ok(())
242 }
243
244 fn compute_gradient_norm(&self, grad: &Tensor) -> Result<f32> {
245 let sum_squares = grad.pow(2.0)?.sum(None, false)?;
247 let norm_tensor = sum_squares.sqrt()?;
248 let norm_data = norm_tensor.data()?;
250 Ok(norm_data[0].sqrt())
251 }
252
253 fn compute_spectral_norm(&self, weight: &Tensor) -> Result<f32> {
254 let weight_data = weight.data()?;
256 let len = weight_data.len();
257
258 if len == 0 {
260 return Ok(0.0);
261 }
262 if len == 1 {
263 return Ok(weight_data[0].abs());
264 }
265
266 if len <= 4 {
268 let frobenius_norm: f32 = weight_data.iter().map(|x| x * x).sum::<f32>().sqrt();
269 return Ok(frobenius_norm);
270 }
271
272 let sqrt_len = (len as f32).sqrt() as usize;
274 let rows = sqrt_len.max(1);
275 let cols = len.div_ceil(rows); let mut v: Vec<f32> = (0..cols).map(|i| ((i % 7) as f32) / 7.0 - 0.5).collect();
279 let mut v_norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
280 if v_norm > 0.0 {
281 for val in &mut v {
282 *val /= v_norm;
283 }
284 }
285
286 for _ in 0..5 {
288 let mut new_v = vec![0.0; rows];
290
291 for i in 0..rows {
293 for j in 0..cols {
294 let idx = i * cols + j;
295 if idx < len && j < v.len() {
296 new_v[i] += weight_data[idx] * v[j];
297 }
298 }
299 }
300
301 v_norm = new_v.iter().map(|x| x * x).sum::<f32>().sqrt();
303 if v_norm > 1e-8 {
304 for item in &mut new_v {
305 *item /= v_norm;
306 }
307 v = new_v;
309 } else {
310 break;
311 }
312 }
313
314 Ok(v_norm.max(1e-8)) }
317}
318
/// Actor-critic reinforcement-learning optimizer with separate Adam
/// instances for the policy and value networks, plus global gradient
/// clipping.
pub struct RLOptimizer {
    policy_optimizer: Adam,
    value_optimizer: Adam,
    // Global L2 gradient-norm cap; `None` disables clipping.
    clip_grad_norm: Option<f32>,
    // Coefficient used by `apply_entropy_regularization`.
    entropy_coeff: f32,
    // Scale applied to value-network gradients before the update.
    value_loss_coeff: f32,
    #[allow(dead_code)]
    max_grad_norm: f32,
}
329
330impl RLOptimizer {
331 pub fn new(
332 policy_lr: f32,
333 value_lr: f32,
334 entropy_coeff: f32,
335 value_loss_coeff: f32,
336 max_grad_norm: f32,
337 ) -> Self {
338 let policy_optimizer = Adam::new(policy_lr, (0.9, 0.999), 1e-8, 0.0);
339 let value_optimizer = Adam::new(value_lr, (0.9, 0.999), 1e-8, 0.0);
340
341 Self {
342 policy_optimizer,
343 value_optimizer,
344 clip_grad_norm: Some(max_grad_norm),
345 entropy_coeff,
346 value_loss_coeff,
347 max_grad_norm,
348 }
349 }
350
351 pub fn step_policy(&mut self, params: &mut [Tensor], grads: &[Tensor]) -> Result<()> {
352 let mut modified_grads = grads.to_vec();
353
354 if let Some(max_norm) = self.clip_grad_norm {
356 self.clip_gradients(&mut modified_grads, max_norm)?;
357 }
358
359 self.apply_entropy_regularization(&mut modified_grads)?;
361
362 for (param, grad) in params.iter_mut().zip(modified_grads.iter()) {
363 self.policy_optimizer.update(param, grad)?;
364 }
365 self.policy_optimizer.step();
366 Ok(())
367 }
368
369 pub fn step_value(&mut self, params: &mut [Tensor], grads: &[Tensor]) -> Result<()> {
370 let mut modified_grads = grads.to_vec();
371
372 for grad in modified_grads.iter_mut() {
374 *grad = grad.mul_scalar(self.value_loss_coeff)?;
375 }
376
377 if let Some(max_norm) = self.clip_grad_norm {
379 self.clip_gradients(&mut modified_grads, max_norm)?;
380 }
381
382 for (param, grad) in params.iter_mut().zip(modified_grads.iter()) {
383 self.value_optimizer.update(param, grad)?;
384 }
385 self.value_optimizer.step();
386 Ok(())
387 }
388
389 fn clip_gradients(&self, gradients: &mut [Tensor], max_norm: f32) -> Result<()> {
390 let mut total_norm_sq: f32 = 0.0;
392 for grad in gradients.iter() {
393 let grad_norm_sq_tensor = grad.pow(2.0)?.sum(None, false)?;
394 let grad_norm_sq_data = grad_norm_sq_tensor.data()?;
395 total_norm_sq += grad_norm_sq_data[0];
396 }
397
398 let total_norm = total_norm_sq.sqrt();
399
400 if total_norm > max_norm {
401 let clip_factor = max_norm / total_norm;
402 for grad in gradients.iter_mut() {
403 *grad = grad.mul_scalar(clip_factor)?;
404 }
405 }
406
407 Ok(())
408 }
409
410 fn apply_entropy_regularization(&self, gradients: &mut [Tensor]) -> Result<()> {
411 for grad in gradients.iter_mut() {
413 let entropy_bonus = grad.mul_scalar(self.entropy_coeff)?;
414 *grad = grad.sub(&entropy_bonus)?;
415 }
416 Ok(())
417 }
418}
419
/// MAML-style meta-learning optimizer: an Adam outer (meta) optimizer
/// wrapped around an SGD inner-loop optimizer.
pub struct MetaOptimizer {
    // Outer-loop optimizer applied to meta-gradients.
    meta_optimizer: Adam,
    // Inner-loop optimizer used during task adaptation.
    inner_optimizer: SGD,
    // Maximum number of adaptation steps per inner loop.
    inner_steps: usize,
    #[allow(dead_code)]
    inner_lr: f32,
    #[allow(dead_code)]
    meta_lr: f32,
    // When true, meta-gradients are used as-is (first-order MAML).
    first_order: bool,
}
431
432impl MetaOptimizer {
433 pub fn new(meta_lr: f32, inner_lr: f32, inner_steps: usize, first_order: bool) -> Self {
434 let meta_optimizer = Adam::new(meta_lr, (0.9, 0.999), 1e-8, 0.0);
435 let inner_optimizer = SGD::new(inner_lr, 0.0, 0.0, false);
436
437 Self {
438 meta_optimizer,
439 inner_optimizer,
440 inner_steps,
441 inner_lr,
442 meta_lr,
443 first_order,
444 }
445 }
446
447 pub fn meta_step(&mut self, params: &mut [Tensor], meta_grads: &[Tensor]) -> Result<()> {
448 for (param, grad) in params.iter_mut().zip(meta_grads.iter()) {
449 self.meta_optimizer.update(param, grad)?;
450 }
451 self.meta_optimizer.step();
452 Ok(())
453 }
454
455 pub fn inner_loop(
456 &mut self,
457 mut params: Vec<Tensor>,
458 task_grads: &[Vec<Tensor>],
459 ) -> Result<Vec<Tensor>> {
460 for step in 0..self.inner_steps {
462 if step < task_grads.len() {
463 let grads = &task_grads[step];
464 for (param, grad) in params.iter_mut().zip(grads.iter()) {
465 self.inner_optimizer.update(param, grad)?;
466 }
467 self.inner_optimizer.step();
468 }
469 }
470 Ok(params)
471 }
472
473 pub fn compute_meta_gradients(
474 &self,
475 original_params: &[Tensor],
476 adapted_params: &[Tensor],
477 meta_loss_grads: &[Tensor],
478 ) -> Result<Vec<Tensor>> {
479 if self.first_order {
480 Ok(meta_loss_grads.to_vec())
482 } else {
483 self.compute_second_order_grads(original_params, adapted_params, meta_loss_grads)
485 }
486 }
487
488 fn compute_second_order_grads(
489 &self,
490 _original_params: &[Tensor],
491 _adapted_params: &[Tensor],
492 meta_loss_grads: &[Tensor],
493 ) -> Result<Vec<Tensor>> {
494 Ok(meta_loss_grads.to_vec())
497 }
498}
499
500pub fn create_bert_optimizer(
502 learning_rate: f32,
503 warmup_steps: usize,
504 total_steps: usize,
505) -> Result<BERTOptimizer> {
506 BERTOptimizer::new(learning_rate, warmup_steps, total_steps, 0.95)
507}
508
509pub fn create_gan_optimizer(g_lr: f32, d_lr: f32, use_spectral_norm: bool) -> GANOptimizer {
510 GANOptimizer::new(g_lr, d_lr, use_spectral_norm, 10.0)
511}
512
513pub fn create_ppo_optimizer(learning_rate: f32, entropy_coeff: f32) -> RLOptimizer {
514 RLOptimizer::new(learning_rate, learning_rate, entropy_coeff, 0.5, 0.5)
515}
516
517pub fn create_maml_optimizer(meta_lr: f32, inner_lr: f32, inner_steps: usize) -> MetaOptimizer {
518 MetaOptimizer::new(meta_lr, inner_lr, inner_steps, false)
519}