scirs2_series/advanced_training_modules/meta_learning.rs

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use super::config::TaskData;
use crate::error::Result;
/// Model-Agnostic Meta-Learning (MAML) over a small two-layer MLP.
///
/// All weights and biases live in one flattened row vector that is unpacked
/// into layer matrices on each forward pass.
#[derive(Debug)]
pub struct MAML<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    /// Flattened model parameters, shape `(1, total_params)`.
    parameters: Array2<F>,
    /// Outer-loop (meta) learning rate.
    meta_lr: F,
    /// Inner-loop (task adaptation) learning rate.
    inner_lr: F,
    /// Number of gradient steps in the inner loop.
    inner_steps: usize,
    /// Input feature dimension.
    input_dim: usize,
    /// Hidden-layer width.
    hidden_dim: usize,
    /// Output dimension.
    output_dim: usize,
}

impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand> MAML<F> {
    /// Creates a new MAML learner with a deterministic Xavier-style initialization.
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        meta_lr: F,
        inner_lr: F,
        inner_steps: usize,
    ) -> Self {
        // Parameter layout: W1 (hidden x input), b1, W2 (output x hidden), b2.
        let total_params =
            input_dim * hidden_dim + hidden_dim + hidden_dim * output_dim + output_dim;
        // Xavier/Glorot scale: std = sqrt(2 / (fan_in + fan_out)).
        let scale = F::from(2.0).unwrap() / F::from(input_dim + output_dim).unwrap();
        let std_dev = scale.sqrt();

        // Deterministic quasi-random values in [-0.5, 0.5), scaled by std_dev,
        // so initialization is reproducible without an RNG dependency.
        let mut parameters = Array2::zeros((1, total_params));
        for i in 0..total_params {
            let val = ((i * 17) % 1000) as f64 / 1000.0 - 0.5;
            parameters[[0, i]] = F::from(val).unwrap() * std_dev;
        }

        Self {
            parameters,
            meta_lr,
            inner_lr,
            inner_steps,
            input_dim,
            hidden_dim,
            output_dim,
        }
    }

    /// Runs one meta-training step over a batch of tasks and returns the
    /// average query-set loss.
    ///
    /// This is first-order MAML: per-task gradients are evaluated at the
    /// adapted parameters but not back-propagated through the inner loop.
    pub fn meta_train(&mut self, tasks: &[TaskData<F>]) -> Result<F> {
        let mut meta_gradients = Array2::zeros(self.parameters.dim());
        let mut total_loss = F::zero();

        for task in tasks {
            // Inner loop: adapt a copy of the parameters to the support set.
            let adapted_params = self.inner_loop_adaptation(task)?;

            // Outer objective: loss and gradient of the adapted parameters
            // on the task's query set.
            let task_loss = self.compute_meta_loss(&adapted_params, task)?;
            let task_gradient = self.compute_meta_gradient(&adapted_params, task)?;

            meta_gradients = meta_gradients + task_gradient;
            total_loss = total_loss + task_loss;
        }

        // Average over tasks, then take one meta gradient-descent step.
        let num_tasks = F::from(tasks.len()).unwrap();
        meta_gradients = meta_gradients / num_tasks;
        total_loss = total_loss / num_tasks;

        self.parameters = self.parameters.clone() - meta_gradients * self.meta_lr;

        Ok(total_loss)
    }

    /// Adapts a copy of the meta parameters to one task by running
    /// `inner_steps` gradient-descent steps on its support set.
    fn inner_loop_adaptation(&self, task: &TaskData<F>) -> Result<Array2<F>> {
        let mut adapted_params = self.parameters.clone();

        for _ in 0..self.inner_steps {
            // `compute_gradients` evaluates the support-set loss internally,
            // so no separate forward pass is needed here.
            let gradients = self.compute_gradients(&adapted_params, task)?;
            adapted_params = adapted_params - gradients * self.inner_lr;
        }

        Ok(adapted_params)
    }

    /// Mean squared error of the model under `params` on the given inputs
    /// and targets.
    fn forward(&self, params: &Array2<F>, inputs: &Array2<F>, targets: &Array2<F>) -> Result<F> {
        let predictions = self.predict(params, inputs)?;

        let mut loss = F::zero();
        let (batch_size, _) = predictions.dim();

        for i in 0..batch_size {
            for j in 0..self.output_dim {
                let diff = predictions[[i, j]] - targets[[i, j]];
                loss = loss + diff * diff;
            }
        }

        Ok(loss / F::from(batch_size).unwrap())
    }

    /// Forward pass of the two-layer MLP: `relu(x W1^T + b1) W2^T + b2`.
    fn predict(&self, params: &Array2<F>, inputs: &Array2<F>) -> Result<Array2<F>> {
        let (batch_size, _) = inputs.dim();

        let (w1, b1, w2, b2) = self.extract_weights(params);

        // Hidden layer with ReLU activation.
        let mut hidden = Array2::zeros((batch_size, self.hidden_dim));
        for i in 0..batch_size {
            for j in 0..self.hidden_dim {
                let mut sum = b1[j];
                for k in 0..self.input_dim {
                    sum = sum + inputs[[i, k]] * w1[[j, k]];
                }
                hidden[[i, j]] = self.relu(sum);
            }
        }

        // Linear output layer.
        let mut output = Array2::zeros((batch_size, self.output_dim));
        for i in 0..batch_size {
            for j in 0..self.output_dim {
                let mut sum = b2[j];
                for k in 0..self.hidden_dim {
                    sum = sum + hidden[[i, k]] * w2[[j, k]];
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }

    /// Unpacks the flattened parameter vector into `(W1, b1, W2, b2)`,
    /// following the layout used in `new`.
    fn extract_weights(&self, params: &Array2<F>) -> (Array2<F>, Array1<F>, Array2<F>, Array1<F>) {
        let param_vec = params.row(0);
        let mut idx = 0;

        // W1: hidden_dim x input_dim.
        let mut w1 = Array2::zeros((self.hidden_dim, self.input_dim));
        for i in 0..self.hidden_dim {
            for j in 0..self.input_dim {
                w1[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b1: hidden_dim.
        let mut b1 = Array1::zeros(self.hidden_dim);
        for i in 0..self.hidden_dim {
            b1[i] = param_vec[idx];
            idx += 1;
        }

        // W2: output_dim x hidden_dim.
        let mut w2 = Array2::zeros((self.output_dim, self.hidden_dim));
        for i in 0..self.output_dim {
            for j in 0..self.hidden_dim {
                w2[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b2: output_dim.
        let mut b2 = Array1::zeros(self.output_dim);
        for i in 0..self.output_dim {
            b2[i] = param_vec[idx];
            idx += 1;
        }

        (w1, b1, w2, b2)
    }

    /// Rectified linear unit: `max(x, 0)`.
    fn relu(&self, x: F) -> F {
        x.max(F::zero())
    }

    /// Estimates the support-set loss gradient with forward finite
    /// differences: `g_i ~ (L(theta + eps*e_i) - L(theta)) / eps`.
    ///
    /// This costs one forward pass per parameter, so it is only practical
    /// for small models; the truncation error is O(eps).
    fn compute_gradients(&self, params: &Array2<F>, task: &TaskData<F>) -> Result<Array2<F>> {
        let epsilon = F::from(1e-5).unwrap();
        let mut gradients = Array2::zeros(params.dim());

        let base_loss = self.forward(params, &task.support_x, &task.support_y)?;

        for i in 0..params.ncols() {
            // Perturb one coordinate and re-evaluate the loss.
            let mut perturbed_params = params.clone();
            perturbed_params[[0, i]] = perturbed_params[[0, i]] + epsilon;

            let perturbed_loss =
                self.forward(&perturbed_params, &task.support_x, &task.support_y)?;
            gradients[[0, i]] = (perturbed_loss - base_loss) / epsilon;
        }

        Ok(gradients)
    }

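    // A central-difference variant, sketched as an unused alternative: it
    // reduces the truncation error to O(eps^2) at the cost of two forward
    // passes per parameter. The name `compute_gradients_central` is ours;
    // nothing in this module calls it.
    #[allow(dead_code)]
    fn compute_gradients_central(
        &self,
        params: &Array2<F>,
        task: &TaskData<F>,
    ) -> Result<Array2<F>> {
        let epsilon = F::from(1e-5).unwrap();
        let two_epsilon = epsilon + epsilon;
        let mut gradients = Array2::zeros(params.dim());

        for i in 0..params.ncols() {
            // Evaluate the loss at theta + eps*e_i and theta - eps*e_i.
            let mut plus = params.clone();
            plus[[0, i]] = plus[[0, i]] + epsilon;
            let mut minus = params.clone();
            minus[[0, i]] = minus[[0, i]] - epsilon;

            let loss_plus = self.forward(&plus, &task.support_x, &task.support_y)?;
            let loss_minus = self.forward(&minus, &task.support_x, &task.support_y)?;
            gradients[[0, i]] = (loss_plus - loss_minus) / two_epsilon;
        }

        Ok(gradients)
    }
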
    /// Estimates the meta gradient by running finite differences on the
    /// query set at the adapted parameters. This is the first-order MAML
    /// approximation: second-order terms through the inner loop are ignored.
    fn compute_meta_gradient(
        &self,
        adapted_params: &Array2<F>,
        task: &TaskData<F>,
    ) -> Result<Array2<F>> {
        // Reuse `compute_gradients` by presenting the query set as a
        // support set.
        self.compute_gradients(
            adapted_params,
            &TaskData {
                support_x: task.query_x.clone(),
                support_y: task.query_y.clone(),
                query_x: task.query_x.clone(),
                query_y: task.query_y.clone(),
            },
        )
    }

    /// Query-set loss of the adapted parameters (the MAML outer objective).
    fn compute_meta_loss(&self, adapted_params: &Array2<F>, task: &TaskData<F>) -> Result<F> {
        self.forward(adapted_params, &task.query_x, &task.query_y)
    }

    /// Adapts the meta-learned parameters to new data at inference time,
    /// returning task-specific parameters without mutating `self`.
    pub fn fast_adapt(&self, support_x: &Array2<F>, support_y: &Array2<F>) -> Result<Array2<F>> {
        // Wrap the data as a task whose query set mirrors its support set.
        let task = TaskData {
            support_x: support_x.clone(),
            support_y: support_y.clone(),
            query_x: support_x.clone(),
            query_y: support_y.clone(),
        };

        self.inner_loop_adaptation(&task)
    }
}
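
// A minimal usage sketch, written as a smoke test. It assumes `TaskData` is a
// plain struct with public `support_x`/`support_y`/`query_x`/`query_y` fields
// (as the literal construction in `compute_meta_gradient` suggests) and that
// `f64` satisfies the trait bounds on `MAML`. The dimensions and learning
// rates are illustrative, not tuned.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn maml_smoke_test() {
        // Tiny model: 2 inputs -> 4 hidden units -> 1 output.
        let mut maml = MAML::<f64>::new(2, 4, 1, 0.01, 0.1, 3);

        // One synthetic regression task (y = x0 + x1 on four points), with
        // the query set mirroring the support set.
        let xs =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]).unwrap();
        let ys = Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 1.0, 2.0]).unwrap();
        let task = TaskData {
            support_x: xs.clone(),
            support_y: ys.clone(),
            query_x: xs.clone(),
            query_y: ys.clone(),
        };

        // A few meta-training steps; the averaged query loss should stay finite.
        for _ in 0..5 {
            let loss = maml.meta_train(std::slice::from_ref(&task)).unwrap();
            assert!(loss.is_finite());
        }

        // Fast adaptation returns flattened task-specific parameters of size
        // W1 + b1 + W2 + b2 = 2*4 + 4 + 4*1 + 1 = 17.
        let adapted = maml.fast_adapt(&xs, &ys).unwrap();
        assert_eq!(adapted.dim(), (1, 17));
    }
}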