scirs2_series/advanced_training_modules/
optimization.rs

//! Meta-Optimization Algorithms
//!
//! This module implements meta-optimization techniques, including learned optimizers
//! that adaptively update neural network parameters based on the optimization history.
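//!
//! A minimal usage sketch (illustrative only; the public path and re-exports may
//! differ in the actual crate layout, hence the `ignore` block):
//!
//! ```ignore
//! let mut opt = MetaOptimizer::<f64>::new(3, 8);
//! let problem = OptimizationProblem::<f64>::quadratic(4, 50);
//! let grad = problem.gradient(&problem.initial_params)[0];
//! let update = opt.generate_update(grad, &[], 0).unwrap();
//! assert!(update.is_finite());
//! ```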

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::Result;

/// Meta-Optimizer using LSTM to generate parameter updates
#[derive(Debug)]
pub struct MetaOptimizer<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    /// LSTM parameters for the optimizer
    #[allow(dead_code)]
    lstm_params: Array2<F>,
    /// Hidden state size
    hidden_size: usize,
    /// Input dimension (gradient + other features)
    input_dim: usize,
    /// Current LSTM hidden state
    hidden_state: Array1<F>,
    /// Current LSTM cell state
    cell_state: Array1<F>,
}

impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand>
    MetaOptimizer<F>
{
    /// Create new meta-optimizer
    pub fn new(input_dim: usize, hidden_size: usize) -> Self {
        // Initialize LSTM parameters
        let param_count = 4 * hidden_size * (input_dim + hidden_size) + 4 * hidden_size; // 4 gates
        let mut lstm_params = Array2::zeros((1, param_count));

        let scale = F::from(1.0).unwrap() / F::from(hidden_size).unwrap().sqrt();
        for i in 0..param_count {
            let val = ((i * 79) % 1000) as f64 / 1000.0 - 0.5;
            lstm_params[[0, i]] = F::from(val).unwrap() * scale;
        }

        let hidden_state = Array1::zeros(hidden_size);
        let cell_state = Array1::zeros(hidden_size);

        Self {
            lstm_params,
            hidden_size,
            input_dim,
            hidden_state,
            cell_state,
        }
    }

    /// Generate parameter update using meta-optimizer
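    ///
    /// A minimal usage sketch (illustrative only; assumes `f64` and a freshly
    /// constructed optimizer):
    ///
    /// ```ignore
    /// let mut opt = MetaOptimizer::<f64>::new(3, 8);
    /// let update = opt.generate_update(0.25, &[1.0, 0.9], 0).unwrap();
    /// assert!(update.is_finite());
    /// ```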
    pub fn generate_update(
        &mut self,
        gradient: F,
        loss_history: &[F],
        step_count: usize,
    ) -> Result<F> {
        // Prepare input features
        let mut input = Array1::zeros(self.input_dim);
        input[0] = gradient;

        if self.input_dim > 1 && !loss_history.is_empty() {
            input[1] = loss_history[loss_history.len() - 1];
        }

        if self.input_dim > 2 {
            input[2] = F::from(step_count).unwrap();
        }

        // LSTM forward pass
        let (new_hidden, new_cell) = self.lstm_forward(&input)?;
        self.hidden_state = new_hidden.clone();
        self.cell_state = new_cell;

        // Generate parameter update (use first output as update)
        Ok(new_hidden[0])
    }

    /// LSTM forward pass
    fn lstm_forward(&self, input: &Array1<F>) -> Result<(Array1<F>, Array1<F>)> {
        // Build gate pre-activations from the input and hidden state
        // (simplified: the learned `lstm_params` are not used here yet)
        let combined_input = self.combine_input_hidden(input);

        // Compute gates (simplified)
        let forget_gate = self.sigmoid(combined_input[0]);
        let input_gate = self.sigmoid(combined_input[1]);
        let candidate_gate = self.tanh(combined_input[2]);
        let output_gate = self.sigmoid(combined_input[3]);

        // Update cell state
        let mut new_cell_state = Array1::zeros(self.hidden_size);
        for i in 0..self.hidden_size {
            new_cell_state[i] = forget_gate * self.cell_state[i] + input_gate * candidate_gate;
        }

        // Update hidden state
        let mut new_hidden_state = Array1::zeros(self.hidden_size);
        for i in 0..self.hidden_size {
            new_hidden_state[i] = output_gate * self.tanh(new_cell_state[i]);
        }

        Ok((new_hidden_state, new_cell_state))
    }

    /// Combine input and hidden state
    fn combine_input_hidden(&self, input: &Array1<F>) -> Array1<F> {
        // Simplified combination - just use input values for gates
        let mut combined = Array1::zeros(4);
        for i in 0..input.len().min(4) {
            combined[i] = input[i];
        }
        combined
    }

    /// Sigmoid activation
    fn sigmoid(&self, x: F) -> F {
        F::one() / (F::one() + (-x).exp())
    }

    /// Hyperbolic tangent activation
    fn tanh(&self, x: F) -> F {
        x.tanh()
    }

    /// Reset optimizer state
    pub fn reset(&mut self) {
        self.hidden_state = Array1::zeros(self.hidden_size);
        self.cell_state = Array1::zeros(self.hidden_size);
    }

    /// Train meta-optimizer on optimization tasks
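    ///
    /// Illustrative sketch (assumes `f64`; the returned value is the step loss
    /// accumulated over all problems, averaged by the number of problems):
    ///
    /// ```ignore
    /// let mut opt = MetaOptimizer::<f64>::new(3, 8);
    /// let problems = vec![
    ///     OptimizationProblem::<f64>::quadratic(2, 20),
    ///     OptimizationProblem::<f64>::random(2, 20),
    /// ];
    /// let mean_loss = opt.meta_train(&problems).unwrap();
    /// assert!(mean_loss.is_finite());
    /// ```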
    pub fn meta_train(&mut self, optimization_problems: &[OptimizationProblem<F>]) -> Result<F> {
        let mut total_loss = F::zero();

        for problem in optimization_problems {
            self.reset();

            let mut current_params = problem.initial_params.clone();
            let mut loss_history = Vec::new();

            // Simulate optimization steps
            for step in 0..problem.max_steps {
                // Compute gradient
                let gradient = self.compute_simple_gradient(&current_params, problem)?;

                // Generate update using meta-optimizer
                let update = self.generate_update(gradient, &loss_history, step)?;

                // Apply update
                current_params = current_params + update;

                // Compute loss
                let loss = self.evaluate_objective(&current_params, problem)?;
                loss_history.push(loss);
                total_loss = total_loss + loss;
            }
        }

        // Average over problems (guard against an empty problem list to avoid NaN)
        Ok(total_loss / F::from(optimization_problems.len().max(1)).unwrap())
    }

    /// Compute simple gradient (placeholder)
    fn compute_simple_gradient(
        &self,
        params: &Array1<F>,
        problem: &OptimizationProblem<F>,
    ) -> Result<F> {
        // Simplified gradient computation
        if !params.is_empty() && !problem.target.is_empty() {
            Ok(params[0] - problem.target[0])
        } else {
            Ok(F::zero())
        }
    }

    /// Evaluate objective function
    fn evaluate_objective(
        &self,
        params: &Array1<F>,
        problem: &OptimizationProblem<F>,
    ) -> Result<F> {
        // Simple quadratic objective
        let mut loss = F::zero();
        for i in 0..params.len().min(problem.target.len()) {
            let diff = params[i] - problem.target[i];
            loss = loss + diff * diff;
        }
        Ok(loss)
    }

    /// Get current hidden state
    pub fn hidden_state(&self) -> &Array1<F> {
        &self.hidden_state
    }

    /// Get current cell state
    pub fn cell_state(&self) -> &Array1<F> {
        &self.cell_state
    }

    /// Set hidden state
    pub fn set_hidden_state(&mut self, state: Array1<F>) -> Result<()> {
        if state.len() != self.hidden_size {
            return Err(crate::error::TimeSeriesError::InvalidOperation(
                "Hidden state size mismatch".to_string(),
            ));
        }
        self.hidden_state = state;
        Ok(())
    }

    /// Set cell state
    pub fn set_cell_state(&mut self, state: Array1<F>) -> Result<()> {
        if state.len() != self.hidden_size {
            return Err(crate::error::TimeSeriesError::InvalidOperation(
                "Cell state size mismatch".to_string(),
            ));
        }
        self.cell_state = state;
        Ok(())
    }

    /// Get optimizer dimensions
    pub fn dimensions(&self) -> (usize, usize) {
        (self.input_dim, self.hidden_size)
    }

    /// Apply meta-optimizer to optimize parameters
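    ///
    /// Illustrative sketch (assumes `f64` and that `Array1` is in scope, e.g.
    /// from `scirs2_core::ndarray`):
    ///
    /// ```ignore
    /// let mut opt = MetaOptimizer::<f64>::new(3, 4);
    /// let initial = Array1::from_vec(vec![2.0, 3.0]);
    /// let target = Array1::from_vec(vec![0.0, 0.0]);
    /// let (final_params, losses) = opt.optimize_parameters(&initial, &target, 10).unwrap();
    /// assert_eq!(losses.len(), 10);
    /// ```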
    pub fn optimize_parameters(
        &mut self,
        initial_params: &Array1<F>,
        target: &Array1<F>,
        max_steps: usize,
    ) -> Result<(Array1<F>, Vec<F>)> {
        self.reset();

        let mut current_params = initial_params.clone();
        let mut loss_history = Vec::new();

        let problem = OptimizationProblem {
            initial_params: initial_params.clone(),
            target: target.clone(),
            max_steps,
        };

        for step in 0..max_steps {
            // Compute gradient
            let gradient = self.compute_simple_gradient(&current_params, &problem)?;

            // Generate update using meta-optimizer
            let update = self.generate_update(gradient, &loss_history, step)?;

            // Apply update
            current_params = current_params + update;

            // Compute and record loss
            let loss = self.evaluate_objective(&current_params, &problem)?;
            loss_history.push(loss);
        }

        Ok((current_params, loss_history))
    }

    /// Generate vectorized updates for multiple parameters
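    ///
    /// Illustrative sketch (assumes `f64`; one update is produced per gradient
    /// component, sharing the same loss history):
    ///
    /// ```ignore
    /// let mut opt = MetaOptimizer::<f64>::new(3, 4);
    /// let gradients = Array1::from_vec(vec![0.1, -0.2, 0.3]);
    /// let updates = opt.generate_vectorized_update(&gradients, &[1.0, 0.8], 0).unwrap();
    /// assert_eq!(updates.len(), 3);
    /// ```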
    pub fn generate_vectorized_update(
        &mut self,
        gradients: &Array1<F>,
        loss_history: &[F],
        step_count: usize,
    ) -> Result<Array1<F>> {
        let mut updates = Array1::zeros(gradients.len());

        for (i, &gradient) in gradients.iter().enumerate() {
            let update = self.generate_update(gradient, loss_history, step_count + i)?;
            updates[i] = update;
        }

        Ok(updates)
    }
}

/// Optimization problem for meta-optimizer training
#[derive(Debug, Clone)]
pub struct OptimizationProblem<F: Float + Debug> {
    /// Initial parameters
    pub initial_params: Array1<F>,
    /// Target parameters
    pub target: Array1<F>,
    /// Maximum optimization steps
    pub max_steps: usize,
}

impl<F: Float + Debug> OptimizationProblem<F> {
    /// Create a new optimization problem
    pub fn new(initial_params: Array1<F>, target: Array1<F>, max_steps: usize) -> Self {
        Self {
            initial_params,
            target,
            max_steps,
        }
    }

    /// Create a quadratic optimization problem
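    ///
    /// Illustrative sketch (assumes `f64`; the target is the origin, so the loss
    /// is the squared distance of the parameters from zero):
    ///
    /// ```ignore
    /// let problem = OptimizationProblem::<f64>::quadratic(3, 100);
    /// assert_eq!(problem.dimension(), 3);
    /// assert!(problem.evaluate(&problem.target) < 1e-12);
    /// ```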
    pub fn quadratic(dim: usize, max_steps: usize) -> Self
    where
        F: FromPrimitive,
    {
        let initial_params = Array1::from_vec(
            (0..dim)
                .map(|i| F::from((i * 13) % 100).unwrap() / F::from(100.0).unwrap())
                .collect(),
        );
        let target = Array1::zeros(dim);

        Self {
            initial_params,
            target,
            max_steps,
        }
    }

    /// Create a pseudo-random (deterministically generated) optimization problem
    pub fn random(dim: usize, max_steps: usize) -> Self
    where
        F: FromPrimitive,
    {
        let initial_params = Array1::from_vec(
            (0..dim)
                .map(|i| F::from((i * 17 + 23) % 200).unwrap() / F::from(100.0).unwrap() - F::one())
                .collect(),
        );
        let target = Array1::from_vec(
            (0..dim)
                .map(|i| F::from((i * 19 + 37) % 100).unwrap() / F::from(200.0).unwrap())
                .collect(),
        );

        Self {
            initial_params,
            target,
            max_steps,
        }
    }

    /// Get the dimension of the problem
    pub fn dimension(&self) -> usize {
        self.initial_params.len()
    }

    /// Evaluate the objective function at given parameters
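    ///
    /// Illustrative sketch (assumes `f64` and `Array1` in scope; the objective is
    /// the squared Euclidean distance to `target` over the overlapping dimensions):
    ///
    /// ```ignore
    /// let problem = OptimizationProblem::new(
    ///     Array1::from_vec(vec![1.0, 2.0]),
    ///     Array1::from_vec(vec![0.0, 0.0]),
    ///     50,
    /// );
    /// assert!((problem.evaluate(&Array1::from_vec(vec![1.0, 1.0])) - 2.0).abs() < 1e-12);
    /// ```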
    pub fn evaluate(&self, params: &Array1<F>) -> F {
        let mut loss = F::zero();
        for i in 0..params.len().min(self.target.len()) {
            let diff = params[i] - self.target[i];
            loss = loss + diff * diff;
        }
        loss
    }

    /// Compute the gradient at given parameters
    pub fn gradient(&self, params: &Array1<F>) -> Array1<F> {
        let mut grad = Array1::zeros(params.len());
        for i in 0..params.len().min(self.target.len()) {
            grad[i] = F::from(2.0).unwrap() * (params[i] - self.target[i]);
        }
        grad
    }

    /// Check if the problem has converged
    pub fn has_converged(&self, params: &Array1<F>, tolerance: F) -> bool {
        self.evaluate(params) < tolerance
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_meta_optimizer_creation() {
        let meta_opt = MetaOptimizer::<f64>::new(3, 5);
        let (input_dim, hidden_size) = meta_opt.dimensions();

        assert_eq!(input_dim, 3);
        assert_eq!(hidden_size, 5);
        assert_eq!(meta_opt.hidden_state().len(), 5);
        assert_eq!(meta_opt.cell_state().len(), 5);
    }

    #[test]
    fn test_meta_optimizer_update_generation() {
        let mut meta_opt = MetaOptimizer::<f64>::new(3, 4);

        let gradient = 0.1;
        let loss_history = vec![1.0, 0.8, 0.6];
        let step_count = 5;

        let update = meta_opt
            .generate_update(gradient, &loss_history, step_count)
            .unwrap();
        assert!(update.is_finite());
    }

    #[test]
    fn test_meta_optimizer_reset() {
        let mut meta_opt = MetaOptimizer::<f64>::new(2, 3);

        // Generate some updates to change state
        let _ = meta_opt.generate_update(0.5, &[1.0], 1).unwrap();

        // Reset should zero out the states
        meta_opt.reset();

        for &val in meta_opt.hidden_state().iter() {
            assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);
        }
        for &val in meta_opt.cell_state().iter() {
            assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_optimization_problem_creation() {
        let initial = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let target = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let problem = OptimizationProblem::new(initial, target, 100);

        assert_eq!(problem.dimension(), 3);
        assert_eq!(problem.max_steps, 100);
    }

    #[test]
    fn test_optimization_problem_evaluation() {
        let initial = Array1::from_vec(vec![1.0, 2.0]);
        let target = Array1::from_vec(vec![0.0, 0.0]);
        let problem = OptimizationProblem::new(initial, target, 50);

        let params = Array1::from_vec(vec![1.0, 1.0]);
        let loss = problem.evaluate(&params);
        let expected = (1.0_f64 - 0.0).powi(2) + (1.0_f64 - 0.0).powi(2);
        assert_abs_diff_eq!(loss, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_optimization_problem_gradient() {
        let initial = Array1::from_vec(vec![2.0, 3.0]);
        let target = Array1::from_vec(vec![1.0, 1.0]);
        let problem = OptimizationProblem::new(initial, target, 50);

        let params = Array1::from_vec(vec![2.0, 3.0]);
        let gradient = problem.gradient(&params);

        // For quadratic loss f(x) = (x - target)^2, gradient is 2(x - target)
        assert_abs_diff_eq!(gradient[0], 2.0 * (2.0 - 1.0), epsilon = 1e-10);
        assert_abs_diff_eq!(gradient[1], 2.0 * (3.0 - 1.0), epsilon = 1e-10);
    }

    #[test]
    fn test_optimization_problem_convergence() {
        let initial = Array1::from_vec(vec![1.0]);
        let target = Array1::from_vec(vec![0.0]);
        let problem = OptimizationProblem::new(initial, target, 50);

        let converged_params = Array1::from_vec(vec![0.001]);
        let not_converged_params = Array1::from_vec(vec![0.5]);

        assert!(problem.has_converged(&converged_params, 0.01));
        assert!(!problem.has_converged(&not_converged_params, 0.01));
    }

    #[test]
    fn test_quadratic_optimization_problem() {
        let problem = OptimizationProblem::<f64>::quadratic(3, 100);

        assert_eq!(problem.dimension(), 3);
        assert_eq!(problem.max_steps, 100);

        // Target should be zeros
        for &val in problem.target.iter() {
            assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_random_optimization_problem() {
        let problem = OptimizationProblem::<f64>::random(4, 200);

        assert_eq!(problem.dimension(), 4);
        assert_eq!(problem.max_steps, 200);
        assert_eq!(problem.initial_params.len(), 4);
        assert_eq!(problem.target.len(), 4);
    }

    #[test]
    fn test_meta_optimizer_state_management() {
        let mut meta_opt = MetaOptimizer::<f64>::new(2, 3);

        let new_hidden = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let new_cell = Array1::from_vec(vec![0.5, 1.5, 2.5]);

        meta_opt.set_hidden_state(new_hidden.clone()).unwrap();
        meta_opt.set_cell_state(new_cell.clone()).unwrap();

        for (i, &val) in meta_opt.hidden_state().iter().enumerate() {
            assert_abs_diff_eq!(val, new_hidden[i], epsilon = 1e-10);
        }
        for (i, &val) in meta_opt.cell_state().iter().enumerate() {
            assert_abs_diff_eq!(val, new_cell[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_meta_optimizer_state_validation() {
        let mut meta_opt = MetaOptimizer::<f64>::new(2, 3);

        // Try to set state with wrong dimensions
        let wrong_state = Array1::from_vec(vec![1.0, 2.0]); // Should be size 3

        let result = meta_opt.set_hidden_state(wrong_state.clone());
        assert!(result.is_err());

        let result = meta_opt.set_cell_state(wrong_state);
        assert!(result.is_err());
    }

    #[test]
    fn test_optimize_parameters() {
        let mut meta_opt = MetaOptimizer::<f64>::new(3, 4);

        let initial = Array1::from_vec(vec![2.0, 3.0]);
        let target = Array1::from_vec(vec![0.0, 0.0]);

        let (final_params, loss_history) =
            meta_opt.optimize_parameters(&initial, &target, 10).unwrap();

        assert_eq!(final_params.len(), 2);
        assert_eq!(loss_history.len(), 10);

        // Loss should generally decrease (though with the simplified optimizer, this might not always be true)
        assert!(loss_history.iter().all(|&loss| loss.is_finite()));
    }

    #[test]
    fn test_vectorized_update_generation() {
        let mut meta_opt = MetaOptimizer::<f64>::new(3, 4);

        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let loss_history = vec![1.0, 0.8];

        let updates = meta_opt
            .generate_vectorized_update(&gradients, &loss_history, 5)
            .unwrap();

        assert_eq!(updates.len(), 3);
        for &update in updates.iter() {
            assert!(update.is_finite());
        }
    }

    #[test]
    fn test_lstm_forward_pass() {
        let meta_opt = MetaOptimizer::<f64>::new(2, 3);
        let input = Array1::from_vec(vec![0.5, -0.3]);

        let (hidden, cell) = meta_opt.lstm_forward(&input).unwrap();

        assert_eq!(hidden.len(), 3);
        assert_eq!(cell.len(), 3);

        for &val in hidden.iter() {
            assert!(val.is_finite());
        }
        for &val in cell.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_activation_functions() {
        let meta_opt = MetaOptimizer::<f64>::new(1, 1);

        // Test sigmoid
        let sigmoid_result = meta_opt.sigmoid(0.0);
        assert_abs_diff_eq!(sigmoid_result, 0.5, epsilon = 1e-10);

        let sigmoid_pos = meta_opt.sigmoid(1000.0); // Should be close to 1
        assert!(sigmoid_pos > 0.99);

        let sigmoid_neg = meta_opt.sigmoid(-1000.0); // Should be close to 0
        assert!(sigmoid_neg < 0.01);

        // Test tanh
        let tanh_result = meta_opt.tanh(0.0);
        assert_abs_diff_eq!(tanh_result, 0.0, epsilon = 1e-10);

        let tanh_pos = meta_opt.tanh(1000.0); // Should be close to 1
        assert!(tanh_pos > 0.99);

        let tanh_neg = meta_opt.tanh(-1000.0); // Should be close to -1
        assert!(tanh_neg < -0.99);
    }
}