scirs2_series/advanced_training_modules/meta_learning.rs

//! Meta-learning algorithms for few-shot time series forecasting
//!
//! This module implements Model-Agnostic Meta-Learning (MAML) and related
//! meta-learning algorithms for rapid adaptation to new time series tasks.

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use super::config::TaskData;
use crate::error::Result;

/// Model-Agnostic Meta-Learning (MAML) for few-shot time series forecasting
#[derive(Debug)]
pub struct MAML<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    /// Base model parameters, flattened into a single row vector
    parameters: Array2<F>,
    /// Meta-learning rate
    meta_lr: F,
    /// Inner-loop learning rate
    inner_lr: F,
    /// Number of inner gradient steps
    inner_steps: usize,
    /// Model dimensions
    input_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
}

impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand> MAML<F> {
    /// Create a new MAML instance
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        meta_lr: F,
        inner_lr: F,
        inner_steps: usize,
    ) -> Self {
        // Initialize parameters with a deterministic pseudo-random pattern
        // scaled by a Xavier-style factor (avoids an RNG dependency)
        let total_params =
            input_dim * hidden_dim + hidden_dim + hidden_dim * output_dim + output_dim;
        let scale = F::from(2.0).unwrap() / F::from(input_dim + output_dim).unwrap();
        let std_dev = scale.sqrt();

        let mut parameters = Array2::zeros((1, total_params));
        for i in 0..total_params {
            // Deterministic values in [-0.5, 0.5), then scaled by std_dev
            let val = ((i * 17) % 1000) as f64 / 1000.0 - 0.5;
            parameters[[0, i]] = F::from(val).unwrap() * std_dev;
        }

        Self {
            parameters,
            meta_lr,
            inner_lr,
            inner_steps,
            input_dim,
            hidden_dim,
            output_dim,
        }
    }

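    // Illustrative construction (a sketch; the dimensions and learning rates
    // below are arbitrary example values, not tuned defaults):
    //
    //     let maml: MAML<f64> = MAML::new(10, 32, 1, 0.001, 0.01, 5);
    //
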
    /// Meta-training step over a batch of tasks
    pub fn meta_train(&mut self, tasks: &[TaskData<F>]) -> Result<F> {
        let mut meta_gradients = Array2::zeros(self.parameters.dim());
        let mut total_loss = F::zero();

        for task in tasks {
            // Inner-loop adaptation on the task's support set
            let adapted_params = self.inner_loop_adaptation(task)?;

            // Evaluate the adapted parameters on the query set
            let task_loss = self.compute_meta_loss(&adapted_params, task)?;
            let task_gradient = self.compute_meta_gradient(&adapted_params, task)?;

            meta_gradients = meta_gradients + task_gradient;
            total_loss = total_loss + task_loss;
        }

        // Average loss and gradients across tasks
        let num_tasks = F::from(tasks.len()).unwrap();
        meta_gradients = meta_gradients / num_tasks;
        total_loss = total_loss / num_tasks;

        // Update meta-parameters
        self.parameters = self.parameters.clone() - meta_gradients * self.meta_lr;

        Ok(total_loss)
    }

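    // Illustrative meta-training loop (a sketch; `tasks` is a hypothetical
    // batch of `TaskData` drawn from a task distribution):
    //
    //     for epoch in 0..num_epochs {
    //         let loss = maml.meta_train(&tasks)?;
    //         println!("epoch {epoch}: meta-loss = {loss:?}");
    //     }
    //
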
    /// Inner-loop adaptation for a single task
    fn inner_loop_adaptation(&self, task: &TaskData<F>) -> Result<Array2<F>> {
        let mut adapted_params = self.parameters.clone();

        // Gradient descent on the support-set loss:
        //     theta' = theta - inner_lr * grad_theta L_support(theta)
        for _ in 0..self.inner_steps {
            let gradients = self.compute_gradients(&adapted_params, task)?;
            adapted_params = adapted_params - gradients * self.inner_lr;
        }

        Ok(adapted_params)
    }

    /// Forward pass through the network, returning the mean squared error
    /// (summed over output dimensions, averaged over the batch)
    fn forward(&self, params: &Array2<F>, inputs: &Array2<F>, targets: &Array2<F>) -> Result<F> {
        let predictions = self.predict(params, inputs)?;

        // Mean squared error loss
        let mut loss = F::zero();
        let (batch_size, _) = predictions.dim();

        for i in 0..batch_size {
            for j in 0..self.output_dim {
                let diff = predictions[[i, j]] - targets[[i, j]];
                loss = loss + diff * diff;
            }
        }

        Ok(loss / F::from(batch_size).unwrap())
    }

    /// Make predictions using the given parameters
    fn predict(&self, params: &Array2<F>, inputs: &Array2<F>) -> Result<Array2<F>> {
        let (batch_size, _) = inputs.dim();

        // Extract weight matrices from flattened parameters
        let (w1, b1, w2, b2) = self.extract_weights(params);

        // Forward pass: input -> hidden -> output
        let mut hidden = Array2::zeros((batch_size, self.hidden_dim));

        // Input to hidden layer
        for i in 0..batch_size {
            for j in 0..self.hidden_dim {
                let mut sum = b1[j];
                for k in 0..self.input_dim {
                    sum = sum + inputs[[i, k]] * w1[[j, k]];
                }
                hidden[[i, j]] = self.relu(sum); // ReLU activation
            }
        }

        // Hidden to output layer
        let mut output = Array2::zeros((batch_size, self.output_dim));
        for i in 0..batch_size {
            for j in 0..self.output_dim {
                let mut sum = b2[j];
                for k in 0..self.hidden_dim {
                    sum = sum + hidden[[i, k]] * w2[[j, k]];
                }
                output[[i, j]] = sum; // Linear output
            }
        }

        Ok(output)
    }

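    // Note: the explicit loops above could equivalently be written with
    // ndarray matrix ops (a sketch, assuming F also satisfies the
    // LinalgScalar bound required by `dot`):
    //
    //     let hidden = (inputs.dot(&w1.t()) + &b1).mapv(|x| x.max(F::zero()));
    //     let output = hidden.dot(&w2.t()) + &b2;
    //
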
    /// Extract weight matrices from the flattened parameter vector
    fn extract_weights(&self, params: &Array2<F>) -> (Array2<F>, Array1<F>, Array2<F>, Array1<F>) {
        let param_vec = params.row(0);
        let mut idx = 0;

        // W1: hidden_dim x input_dim
        let mut w1 = Array2::zeros((self.hidden_dim, self.input_dim));
        for i in 0..self.hidden_dim {
            for j in 0..self.input_dim {
                w1[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b1: hidden_dim
        let mut b1 = Array1::zeros(self.hidden_dim);
        for i in 0..self.hidden_dim {
            b1[i] = param_vec[idx];
            idx += 1;
        }

        // W2: output_dim x hidden_dim
        let mut w2 = Array2::zeros((self.output_dim, self.hidden_dim));
        for i in 0..self.output_dim {
            for j in 0..self.hidden_dim {
                w2[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b2: output_dim
        let mut b2 = Array1::zeros(self.output_dim);
        for i in 0..self.output_dim {
            b2[i] = param_vec[idx];
            idx += 1;
        }

        (w1, b1, w2, b2)
    }

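    // Flattened parameter layout (offsets into the single parameter row):
    //
    //     [ W1: hidden*input | b1: hidden | W2: output*hidden | b2: output ]
    //
    // The order consumed here matches the `total_params` computed in `new`.
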
    /// ReLU activation function
    fn relu(&self, x: F) -> F {
        x.max(F::zero())
    }

    /// Compute gradients by forward-difference numerical differentiation
    /// (one extra forward pass per parameter)
    fn compute_gradients(&self, params: &Array2<F>, task: &TaskData<F>) -> Result<Array2<F>> {
        let epsilon = F::from(1e-5).unwrap();
        let mut gradients = Array2::zeros(params.dim());

        let base_loss = self.forward(params, &task.support_x, &task.support_y)?;

        for i in 0..params.ncols() {
            let mut perturbed_params = params.clone();
            perturbed_params[[0, i]] = perturbed_params[[0, i]] + epsilon;

            let perturbed_loss =
                self.forward(&perturbed_params, &task.support_x, &task.support_y)?;
            gradients[[0, i]] = (perturbed_loss - base_loss) / epsilon;
        }

        Ok(gradients)
    }

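    // A central-difference variant of the loop body above (a sketch; roughly
    // twice the cost, but O(epsilon^2) truncation error instead of O(epsilon)):
    //
    //     let mut minus = params.clone();
    //     minus[[0, i]] = minus[[0, i]] - epsilon;
    //     let loss_minus = self.forward(&minus, &task.support_x, &task.support_y)?;
    //     gradients[[0, i]] = (perturbed_loss - loss_minus) / (epsilon + epsilon);
    //
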
    /// Compute the meta-gradient for the meta-learning update
    ///
    /// This is a first-order approximation (as in first-order MAML): the
    /// gradient is evaluated at the adapted parameters on the query set,
    /// ignoring second-order terms from differentiating through the inner loop.
    fn compute_meta_gradient(
        &self,
        adapted_params: &Array2<F>,
        task: &TaskData<F>,
    ) -> Result<Array2<F>> {
        // Reuse the support-set gradient routine by presenting the query set
        // as a support set
        self.compute_gradients(
            adapted_params,
            &TaskData {
                support_x: task.query_x.clone(),
                support_y: task.query_y.clone(),
                query_x: task.query_x.clone(),
                query_y: task.query_y.clone(),
            },
        )
    }

    /// Compute the meta-loss on the query set
    fn compute_meta_loss(&self, adapted_params: &Array2<F>, task: &TaskData<F>) -> Result<F> {
        self.forward(adapted_params, &task.query_x, &task.query_y)
    }

    /// Fast adaptation to a new task (few-shot learning)
    ///
    /// Returns the adapted parameters. The query set is set equal to the
    /// support set, since only inner-loop adaptation is performed here.
    pub fn fast_adapt(&self, support_x: &Array2<F>, support_y: &Array2<F>) -> Result<Array2<F>> {
        let task = TaskData {
            support_x: support_x.clone(),
            support_y: support_y.clone(),
            query_x: support_x.clone(),
            query_y: support_y.clone(),
        };

        self.inner_loop_adaptation(&task)
    }
}
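
// Minimal smoke tests (an illustrative sketch: the synthetic task below is
// hypothetical toy data, and `TaskData` is assumed to be constructible with
// the four array fields used throughout this module).
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a deterministic toy task: targets are the mean of the inputs.
    fn synthetic_task(n: usize, input_dim: usize, output_dim: usize) -> TaskData<f64> {
        let x = Array2::from_shape_fn((n, input_dim), |(i, j)| ((i + j) as f64) / 10.0);
        let y = Array2::from_shape_fn((n, output_dim), |(i, _)| {
            x.row(i).sum() / input_dim as f64
        });
        TaskData {
            support_x: x.clone(),
            support_y: y.clone(),
            query_x: x,
            query_y: y,
        }
    }

    #[test]
    fn meta_train_returns_finite_loss() {
        let mut maml: MAML<f64> = MAML::new(4, 8, 1, 1e-3, 1e-2, 3);
        let tasks = vec![synthetic_task(6, 4, 1), synthetic_task(6, 4, 1)];
        let loss = maml.meta_train(&tasks).unwrap();
        // MSE is non-negative and should stay finite on toy data
        assert!(loss.is_finite() && loss >= 0.0);
    }

    #[test]
    fn fast_adapt_preserves_parameter_shape() {
        let maml: MAML<f64> = MAML::new(4, 8, 1, 1e-3, 1e-2, 3);
        let task = synthetic_task(6, 4, 1);
        let adapted = maml.fast_adapt(&task.support_x, &task.support_y).unwrap();
        // Adapted parameters must have the same flattened shape as the base
        assert_eq!(adapted.dim(), maml.parameters.dim());
    }
}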