//! Task-specific optimizers (BERT, GAN, RL, meta-learning) for
//! `trustformers_optim` (task_specific.rs).
1use crate::{
2    adam::{Adam, AdamW},
3    scheduler::LRScheduler,
4    sgd::SGD,
5};
6use trustformers_core::{errors::Result, tensor::Tensor, traits::Optimizer};
7
8/// BERT-specific optimizer with tailored hyperparameters and scheduling
9pub struct BERTOptimizer {
10    base_optimizer: AdamW,
11    warmup_scheduler: Box<dyn LRScheduler>,
12    #[allow(dead_code)]
13    layer_wise_decay: f32,
14    #[allow(dead_code)]
15    weight_decay_exclusions: Vec<String>,
16    current_step: usize,
17    #[allow(dead_code)]
18    warmup_steps: usize,
19    #[allow(dead_code)]
20    total_steps: usize,
21}
22
23impl BERTOptimizer {
24    pub fn new(
25        learning_rate: f32,
26        warmup_steps: usize,
27        total_steps: usize,
28        layer_wise_decay: f32,
29    ) -> Result<Self> {
30        let base_optimizer = AdamW::new(learning_rate, (0.9, 0.999), 1e-6, 0.01);
31
32        // BERT-specific warmup scheduler
33        let warmup_scheduler = Box::new(BERTWarmupScheduler::new(
34            learning_rate,
35            warmup_steps,
36            total_steps,
37        ));
38
39        // Parameters that should not have weight decay (bias, LayerNorm)
40        let weight_decay_exclusions = vec![
41            "bias".to_string(),
42            "LayerNorm".to_string(),
43            "layer_norm".to_string(),
44            "ln".to_string(),
45        ];
46
47        Ok(Self {
48            base_optimizer,
49            warmup_scheduler,
50            layer_wise_decay,
51            weight_decay_exclusions,
52            current_step: 0,
53            warmup_steps,
54            total_steps,
55        })
56    }
57
58    /// Apply layer-wise learning rate decay for deeper layers
59    #[allow(dead_code)]
60    fn get_layer_wise_lr(&self, param_name: &str, base_lr: f32) -> f32 {
61        // Extract layer number from parameter name
62        if let Some(layer_num) = self.extract_layer_number(param_name) {
63            let decay_factor = self.layer_wise_decay.powi(layer_num as i32);
64            base_lr * decay_factor
65        } else {
66            base_lr
67        }
68    }
69
70    fn extract_layer_number(&self, param_name: &str) -> Option<usize> {
71        // Extract layer number from names like "encoder.layer.11.attention.self.query.weight"
72        if param_name.contains("layer.") {
73            let parts: Vec<&str> = param_name.split('.').collect();
74            for i in 0..parts.len() {
75                if parts[i] == "layer" && i + 1 < parts.len() {
76                    if let Ok(layer_num) = parts[i + 1].parse::<usize>() {
77                        return Some(layer_num);
78                    }
79                }
80            }
81        }
82        None
83    }
84
85    #[allow(dead_code)]
86    fn should_exclude_weight_decay(&self, param_name: &str) -> bool {
87        self.weight_decay_exclusions
88            .iter()
89            .any(|exclusion| param_name.contains(exclusion))
90    }
91}
92
93impl Optimizer for BERTOptimizer {
94    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
95        self.base_optimizer.update(parameter, grad)
96    }
97
98    fn zero_grad(&mut self) {
99        self.base_optimizer.zero_grad()
100    }
101
102    fn step(&mut self) {
103        self.base_optimizer.step();
104        self.warmup_scheduler.step();
105        self.current_step += 1;
106    }
107
108    fn get_lr(&self) -> f32 {
109        self.base_optimizer.get_lr()
110    }
111
112    fn set_lr(&mut self, lr: f32) {
113        self.base_optimizer.set_lr(lr)
114    }
115}
116
/// BERT warmup scheduler: linear LR warmup followed by linear decay.
struct BERTWarmupScheduler {
    /// Peak learning rate, reached at the end of warmup.
    base_lr: f32,
    /// Number of steps spent warming up from 0 to `base_lr`.
    warmup_steps: usize,
    /// Step at which the decayed LR reaches 0.
    total_steps: usize,
    /// Steps taken so far (advanced by `step()`).
    current_step: usize,
}
124
125impl BERTWarmupScheduler {
126    fn new(base_lr: f32, warmup_steps: usize, total_steps: usize) -> Self {
127        Self {
128            base_lr,
129            warmup_steps,
130            total_steps,
131            current_step: 0,
132        }
133    }
134}
135
136impl LRScheduler for BERTWarmupScheduler {
137    fn step(&mut self) {
138        self.current_step += 1;
139    }
140
141    fn get_lr(&self, step: usize) -> f32 {
142        if step < self.warmup_steps {
143            // Linear warmup
144            self.base_lr * (step as f32 / self.warmup_steps as f32)
145        } else {
146            // Linear decay
147            let progress =
148                (step - self.warmup_steps) as f32 / (self.total_steps - self.warmup_steps) as f32;
149            self.base_lr * (1.0 - progress).max(0.0)
150        }
151    }
152}
153
154/// GAN optimizer with stability improvements
155pub struct GANOptimizer {
156    generator_optimizer: Adam,
157    discriminator_optimizer: Adam,
158    spectral_norm: bool,
159    gradient_penalty_weight: f32,
160    #[allow(dead_code)]
161    ttur: bool, // Two Time-scale Update Rule
162    d_steps_per_g_step: usize,
163    current_d_steps: usize,
164}
165
166impl GANOptimizer {
167    pub fn new(g_lr: f32, d_lr: f32, spectral_norm: bool, gradient_penalty_weight: f32) -> Self {
168        let generator_optimizer = Adam::new(g_lr, (0.0, 0.999), 1e-8, 0.0);
169        let discriminator_optimizer = Adam::new(d_lr, (0.0, 0.999), 1e-8, 0.0);
170
171        Self {
172            generator_optimizer,
173            discriminator_optimizer,
174            spectral_norm,
175            gradient_penalty_weight,
176            ttur: d_lr != g_lr,
177            d_steps_per_g_step: if d_lr > g_lr { 5 } else { 1 },
178            current_d_steps: 0,
179        }
180    }
181
182    pub fn step_discriminator(
183        &mut self,
184        d_params: &mut [Tensor],
185        d_grads: &[Tensor],
186    ) -> Result<()> {
187        // Apply gradient penalty if enabled
188        let mut modified_grads = d_grads.to_vec();
189        if self.gradient_penalty_weight > 0.0 {
190            self.apply_gradient_penalty(&mut modified_grads)?;
191        }
192
193        // Apply spectral normalization if enabled
194        if self.spectral_norm {
195            self.apply_spectral_norm(d_params)?;
196        }
197
198        for (param, grad) in d_params.iter_mut().zip(modified_grads.iter()) {
199            self.discriminator_optimizer.update(param, grad)?;
200        }
201        self.discriminator_optimizer.step();
202        self.current_d_steps += 1;
203        Ok(())
204    }
205
206    pub fn step_generator(&mut self, g_params: &mut [Tensor], g_grads: &[Tensor]) -> Result<()> {
207        // Only update generator after enough discriminator steps
208        if self.current_d_steps >= self.d_steps_per_g_step {
209            for (param, grad) in g_params.iter_mut().zip(g_grads.iter()) {
210                self.generator_optimizer.update(param, grad)?;
211            }
212            self.generator_optimizer.step();
213            self.current_d_steps = 0;
214        }
215        Ok(())
216    }
217
218    fn apply_gradient_penalty(&self, gradients: &mut [Tensor]) -> Result<()> {
219        // Apply gradient penalty to encourage Lipschitz constraint
220        for grad in gradients.iter_mut() {
221            let grad_norm = self.compute_gradient_norm(grad)?;
222            if grad_norm > 1.0 {
223                let penalty = (grad_norm - 1.0).powi(2) * self.gradient_penalty_weight;
224                *grad = grad.add_scalar(penalty)?;
225            }
226        }
227        Ok(())
228    }
229
230    fn apply_spectral_norm(&self, parameters: &mut [Tensor]) -> Result<()> {
231        // Apply spectral normalization to weight matrices
232        for param in parameters.iter_mut() {
233            if param.shape().len() >= 2 {
234                // Only for weight matrices
235                let spectral_norm = self.compute_spectral_norm(param)?;
236                if spectral_norm > 1.0 {
237                    *param = param.div_scalar(spectral_norm)?;
238                }
239            }
240        }
241        Ok(())
242    }
243
244    fn compute_gradient_norm(&self, grad: &Tensor) -> Result<f32> {
245        // Compute L2 norm of gradient
246        let sum_squares = grad.pow(2.0)?.sum(None, false)?;
247        let norm_tensor = sum_squares.sqrt()?;
248        // Extract scalar value from tensor
249        let norm_data = norm_tensor.data()?;
250        Ok(norm_data[0].sqrt())
251    }
252
253    fn compute_spectral_norm(&self, weight: &Tensor) -> Result<f32> {
254        // Spectral norm computation using power iteration method
255        let weight_data = weight.data()?;
256        let len = weight_data.len();
257
258        // Handle edge cases
259        if len == 0 {
260            return Ok(0.0);
261        }
262        if len == 1 {
263            return Ok(weight_data[0].abs());
264        }
265
266        // For very small matrices, use simple Frobenius norm approximation
267        if len <= 4 {
268            let frobenius_norm: f32 = weight_data.iter().map(|x| x * x).sum::<f32>().sqrt();
269            return Ok(frobenius_norm);
270        }
271
272        // Power iteration method for spectral norm (largest singular value)
273        let sqrt_len = (len as f32).sqrt() as usize;
274        let rows = sqrt_len.max(1);
275        let cols = len.div_ceil(rows); // Ceiling division
276
277        // Initialize random vector
278        let mut v: Vec<f32> = (0..cols).map(|i| ((i % 7) as f32) / 7.0 - 0.5).collect();
279        let mut v_norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
280        if v_norm > 0.0 {
281            for val in &mut v {
282                *val /= v_norm;
283            }
284        }
285
286        // Power iteration (simplified - assumes roughly square matrix)
287        for _ in 0..5 {
288            // 5 iterations usually sufficient
289            let mut new_v = vec![0.0; rows];
290
291            // Matrix-vector multiplication: W^T * W * v
292            for i in 0..rows {
293                for j in 0..cols {
294                    let idx = i * cols + j;
295                    if idx < len && j < v.len() {
296                        new_v[i] += weight_data[idx] * v[j];
297                    }
298                }
299            }
300
301            // Compute norm
302            v_norm = new_v.iter().map(|x| x * x).sum::<f32>().sqrt();
303            if v_norm > 1e-8 {
304                for item in &mut new_v {
305                    *item /= v_norm;
306                }
307                // Resize v to match new_v for next iteration
308                v = new_v;
309            } else {
310                break;
311            }
312        }
313
314        // The spectral norm is approximately the final norm
315        Ok(v_norm.max(1e-8)) // Avoid zero values
316    }
317}
318
319/// Reinforcement Learning optimizer with specialized features
320pub struct RLOptimizer {
321    policy_optimizer: Adam,
322    value_optimizer: Adam,
323    clip_grad_norm: Option<f32>,
324    entropy_coeff: f32,
325    value_loss_coeff: f32,
326    #[allow(dead_code)]
327    max_grad_norm: f32,
328}
329
330impl RLOptimizer {
331    pub fn new(
332        policy_lr: f32,
333        value_lr: f32,
334        entropy_coeff: f32,
335        value_loss_coeff: f32,
336        max_grad_norm: f32,
337    ) -> Self {
338        let policy_optimizer = Adam::new(policy_lr, (0.9, 0.999), 1e-8, 0.0);
339        let value_optimizer = Adam::new(value_lr, (0.9, 0.999), 1e-8, 0.0);
340
341        Self {
342            policy_optimizer,
343            value_optimizer,
344            clip_grad_norm: Some(max_grad_norm),
345            entropy_coeff,
346            value_loss_coeff,
347            max_grad_norm,
348        }
349    }
350
351    pub fn step_policy(&mut self, params: &mut [Tensor], grads: &[Tensor]) -> Result<()> {
352        let mut modified_grads = grads.to_vec();
353
354        // Apply gradient clipping
355        if let Some(max_norm) = self.clip_grad_norm {
356            self.clip_gradients(&mut modified_grads, max_norm)?;
357        }
358
359        // Apply entropy regularization
360        self.apply_entropy_regularization(&mut modified_grads)?;
361
362        for (param, grad) in params.iter_mut().zip(modified_grads.iter()) {
363            self.policy_optimizer.update(param, grad)?;
364        }
365        self.policy_optimizer.step();
366        Ok(())
367    }
368
369    pub fn step_value(&mut self, params: &mut [Tensor], grads: &[Tensor]) -> Result<()> {
370        let mut modified_grads = grads.to_vec();
371
372        // Scale value gradients
373        for grad in modified_grads.iter_mut() {
374            *grad = grad.mul_scalar(self.value_loss_coeff)?;
375        }
376
377        // Apply gradient clipping
378        if let Some(max_norm) = self.clip_grad_norm {
379            self.clip_gradients(&mut modified_grads, max_norm)?;
380        }
381
382        for (param, grad) in params.iter_mut().zip(modified_grads.iter()) {
383            self.value_optimizer.update(param, grad)?;
384        }
385        self.value_optimizer.step();
386        Ok(())
387    }
388
389    fn clip_gradients(&self, gradients: &mut [Tensor], max_norm: f32) -> Result<()> {
390        // Compute global gradient norm
391        let mut total_norm_sq: f32 = 0.0;
392        for grad in gradients.iter() {
393            let grad_norm_sq_tensor = grad.pow(2.0)?.sum(None, false)?;
394            let grad_norm_sq_data = grad_norm_sq_tensor.data()?;
395            total_norm_sq += grad_norm_sq_data[0];
396        }
397
398        let total_norm = total_norm_sq.sqrt();
399
400        if total_norm > max_norm {
401            let clip_factor = max_norm / total_norm;
402            for grad in gradients.iter_mut() {
403                *grad = grad.mul_scalar(clip_factor)?;
404            }
405        }
406
407        Ok(())
408    }
409
410    fn apply_entropy_regularization(&self, gradients: &mut [Tensor]) -> Result<()> {
411        // Add entropy bonus to encourage exploration
412        for grad in gradients.iter_mut() {
413            let entropy_bonus = grad.mul_scalar(self.entropy_coeff)?;
414            *grad = grad.sub(&entropy_bonus)?;
415        }
416        Ok(())
417    }
418}
419
420/// Meta-learning optimizer (MAML-style)
421pub struct MetaOptimizer {
422    meta_optimizer: Adam,
423    inner_optimizer: SGD,
424    inner_steps: usize,
425    #[allow(dead_code)]
426    inner_lr: f32,
427    #[allow(dead_code)]
428    meta_lr: f32,
429    first_order: bool, // Use first-order approximation
430}
431
432impl MetaOptimizer {
433    pub fn new(meta_lr: f32, inner_lr: f32, inner_steps: usize, first_order: bool) -> Self {
434        let meta_optimizer = Adam::new(meta_lr, (0.9, 0.999), 1e-8, 0.0);
435        let inner_optimizer = SGD::new(inner_lr, 0.0, 0.0, false);
436
437        Self {
438            meta_optimizer,
439            inner_optimizer,
440            inner_steps,
441            inner_lr,
442            meta_lr,
443            first_order,
444        }
445    }
446
447    pub fn meta_step(&mut self, params: &mut [Tensor], meta_grads: &[Tensor]) -> Result<()> {
448        for (param, grad) in params.iter_mut().zip(meta_grads.iter()) {
449            self.meta_optimizer.update(param, grad)?;
450        }
451        self.meta_optimizer.step();
452        Ok(())
453    }
454
455    pub fn inner_loop(
456        &mut self,
457        mut params: Vec<Tensor>,
458        task_grads: &[Vec<Tensor>],
459    ) -> Result<Vec<Tensor>> {
460        // Perform inner loop adaptation for a specific task
461        for step in 0..self.inner_steps {
462            if step < task_grads.len() {
463                let grads = &task_grads[step];
464                for (param, grad) in params.iter_mut().zip(grads.iter()) {
465                    self.inner_optimizer.update(param, grad)?;
466                }
467                self.inner_optimizer.step();
468            }
469        }
470        Ok(params)
471    }
472
473    pub fn compute_meta_gradients(
474        &self,
475        original_params: &[Tensor],
476        adapted_params: &[Tensor],
477        meta_loss_grads: &[Tensor],
478    ) -> Result<Vec<Tensor>> {
479        if self.first_order {
480            // First-order approximation (ignore second derivatives)
481            Ok(meta_loss_grads.to_vec())
482        } else {
483            // Second-order gradients through inner loop
484            self.compute_second_order_grads(original_params, adapted_params, meta_loss_grads)
485        }
486    }
487
488    fn compute_second_order_grads(
489        &self,
490        _original_params: &[Tensor],
491        _adapted_params: &[Tensor],
492        meta_loss_grads: &[Tensor],
493    ) -> Result<Vec<Tensor>> {
494        // Simplified second-order gradient computation
495        // In practice, would use automatic differentiation
496        Ok(meta_loss_grads.to_vec())
497    }
498}
499
500/// Factory functions for creating task-specific optimizers
501pub fn create_bert_optimizer(
502    learning_rate: f32,
503    warmup_steps: usize,
504    total_steps: usize,
505) -> Result<BERTOptimizer> {
506    BERTOptimizer::new(learning_rate, warmup_steps, total_steps, 0.95)
507}
508
509pub fn create_gan_optimizer(g_lr: f32, d_lr: f32, use_spectral_norm: bool) -> GANOptimizer {
510    GANOptimizer::new(g_lr, d_lr, use_spectral_norm, 10.0)
511}
512
513pub fn create_ppo_optimizer(learning_rate: f32, entropy_coeff: f32) -> RLOptimizer {
514    RLOptimizer::new(learning_rate, learning_rate, entropy_coeff, 0.5, 0.5)
515}
516
517pub fn create_maml_optimizer(meta_lr: f32, inner_lr: f32, inner_steps: usize) -> MetaOptimizer {
518    MetaOptimizer::new(meta_lr, inner_lr, inner_steps, false)
519}