use anyhow::Result;
use scirs2_core::ndarray::{s, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Common interface for gradient-based meta-learning algorithms such as MAML and Reptile.
pub trait MetaLearningAlgorithm {
    /// Performs one meta-update over a batch of tasks and returns diagnostics.
    fn meta_update(&mut self, task_batch: &TaskBatch) -> Result<MetaUpdateResult>;
    /// Adapts the current meta-parameters to a task's support set.
    fn adapt(&self, support_set: &TaskData, adaptation_steps: usize) -> Result<ModelParameters>;
    /// Evaluates the given parameters on a query set, returning the loss.
    fn evaluate(&self, params: &ModelParameters, query_set: &TaskData) -> Result<f32>;
}
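
// A minimal usage sketch of the trait, assuming a hypothetical
// `sample_task_batch()` helper that yields a `TaskBatch` (not part of this module):
//
//     let mut trainer = MAMLTrainer::new(MAMLConfig::default(), initial_parameters);
//     for _ in 0..num_meta_steps {
//         let batch = sample_task_batch()?;
//         let result = trainer.meta_update(&batch)?;
//         println!("meta loss: {:.4}", result.meta_loss);
//     }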

/// Configuration for MAML (Model-Agnostic Meta-Learning).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MAMLConfig {
    /// Outer-loop (meta) learning rate.
    pub meta_lr: f32,
    /// Inner-loop (task-adaptation) learning rate.
    pub inner_lr: f32,
    /// Number of gradient steps in the inner adaptation loop.
    pub adaptation_steps: usize,
    /// Number of tasks per meta-update.
    pub meta_batch_size: usize,
    /// Use the first-order approximation (skip second-order terms).
    pub first_order: bool,
    /// Maximum global gradient norm; larger gradients are rescaled.
    pub grad_clip: f32,
    /// Learn per-parameter inner learning rates (currently unimplemented).
    pub learn_inner_lrs: bool,
}

impl Default for MAMLConfig {
    fn default() -> Self {
        Self {
            meta_lr: 0.001,
            inner_lr: 0.01,
            adaptation_steps: 5,
            meta_batch_size: 16,
            first_order: false,
            grad_clip: 10.0,
            learn_inner_lrs: false,
        }
    }
}

/// Configuration for the Reptile meta-learning algorithm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReptileConfig {
    /// Outer-loop (meta) learning rate.
    pub meta_lr: f32,
    /// Inner-loop (task-adaptation) learning rate.
    pub inner_lr: f32,
    /// Number of SGD steps taken on each task.
    pub adaptation_steps: usize,
    /// Number of tasks per meta-update.
    pub meta_batch_size: usize,
    /// Maximum global gradient norm; larger gradients are rescaled.
    pub grad_clip: f32,
}

impl Default for ReptileConfig {
    fn default() -> Self {
        Self {
            meta_lr: 0.001,
            inner_lr: 0.01,
            adaptation_steps: 5,
            meta_batch_size: 16,
            grad_clip: 10.0,
        }
    }
}

/// A named collection of 2-D parameter tensors together with their shapes.
#[derive(Debug, Clone)]
pub struct ModelParameters {
    /// Parameter tensors keyed by name.
    pub parameters: HashMap<String, Array2<f32>>,
    /// Recorded shape of each parameter, keyed by name.
    pub shapes: HashMap<String, Vec<usize>>,
}

impl Default for ModelParameters {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelParameters {
    pub fn new() -> Self {
        Self {
            parameters: HashMap::new(),
            shapes: HashMap::new(),
        }
    }

    pub fn add_parameter(&mut self, name: String, tensor: Array2<f32>) {
        let shape = tensor.shape().to_vec();
        self.shapes.insert(name.clone(), shape);
        self.parameters.insert(name, tensor);
    }

    pub fn get_parameter(&self, name: &str) -> Option<&Array2<f32>> {
        self.parameters.get(name)
    }

    pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut Array2<f32>> {
        self.parameters.get_mut(name)
    }

    pub fn clone_parameters(&self) -> Self {
        Self {
            parameters: self.parameters.clone(),
            shapes: self.shapes.clone(),
        }
    }

    /// Applies one gradient step to every parameter that has a matching
    /// gradient: `param <- param - learning_rate * grad`.
    pub fn update_with_gradients(&mut self, gradients: &Self, learning_rate: f32) -> Result<()> {
        for (name, param) in &mut self.parameters {
            if let Some(grad) = gradients.get_parameter(name) {
                *param = param.clone() - learning_rate * grad;
            }
        }
        Ok(())
    }

    /// Element-wise difference `self - other` over parameters present in both.
    pub fn subtract(&self, other: &Self) -> Result<Self> {
        let mut result = Self::new();

        for (name, param) in &self.parameters {
            if let Some(other_param) = other.get_parameter(name) {
                let diff = param - other_param;
                result.add_parameter(name.clone(), diff);
            }
        }

        Ok(result)
    }

    /// Element-wise sum `self + other` over parameters present in both.
    pub fn add(&self, other: &Self) -> Result<Self> {
        let mut result = Self::new();

        for (name, param) in &self.parameters {
            if let Some(other_param) = other.get_parameter(name) {
                let sum = param + other_param;
                result.add_parameter(name.clone(), sum);
            }
        }

        Ok(result)
    }

    /// Multiplies every parameter by `scalar`.
    pub fn scale(&self, scalar: f32) -> Self {
        let mut result = Self::new();

        for (name, param) in &self.parameters {
            let scaled = param * scalar;
            result.add_parameter(name.clone(), scaled);
        }

        result
    }
}

/// A single task's data: an `(inputs, targets)` pair plus an identifier.
#[derive(Debug, Clone)]
pub struct TaskData {
    /// Input matrix, one sample per row.
    pub inputs: Array2<f32>,
    /// Target matrix, one sample per row.
    pub targets: Array2<f32>,
    /// Identifier used to track which task the data belongs to.
    pub task_id: String,
}

impl TaskData {
    pub fn new(inputs: Array2<f32>, targets: Array2<f32>, task_id: String) -> Self {
        Self {
            inputs,
            targets,
            task_id,
        }
    }

    pub fn batch_size(&self) -> usize {
        self.inputs.nrows()
    }

    /// Splits the task data into mini-batches of at most `batch_size` rows.
    pub fn split_batches(&self, batch_size: usize) -> Vec<TaskData> {
        let total_samples = self.batch_size();
        let mut batches = Vec::new();

        for start in (0..total_samples).step_by(batch_size) {
            let end = (start + batch_size).min(total_samples);
            let batch_inputs = self.inputs.slice(s![start..end, ..]).to_owned();
            let batch_targets = self.targets.slice(s![start..end, ..]).to_owned();

            batches.push(TaskData::new(
                batch_inputs,
                batch_targets,
                format!("{}_batch_{}", self.task_id, start / batch_size),
            ));
        }

        batches
    }
}

/// A batch of tasks, each with a support set (for adaptation) and a query set
/// (for meta-evaluation).
#[derive(Debug)]
pub struct TaskBatch {
    pub support_sets: Vec<TaskData>,
    pub query_sets: Vec<TaskData>,
}

impl TaskBatch {
    pub fn new(support_sets: Vec<TaskData>, query_sets: Vec<TaskData>) -> Result<Self> {
        if support_sets.len() != query_sets.len() {
            return Err(anyhow::anyhow!(
                "Support and query sets must have the same length"
            ));
        }
        Ok(Self {
            support_sets,
            query_sets,
        })
    }

    pub fn num_tasks(&self) -> usize {
        self.support_sets.len()
    }
}

/// Diagnostics returned by a single meta-update.
#[derive(Debug)]
pub struct MetaUpdateResult {
    /// Mean query loss across the task batch.
    pub meta_loss: f32,
    /// Per-task query losses.
    pub task_losses: Vec<f32>,
    /// Global gradient norm before clipping.
    pub grad_norm: f32,
    /// Snapshot of the meta-parameters after the update.
    pub updated_parameters: ModelParameters,
}

/// Trainer implementing MAML with finite-difference gradients.
pub struct MAMLTrainer {
    config: MAMLConfig,
    meta_parameters: Arc<RwLock<ModelParameters>>,
    #[allow(dead_code)]
    optimizer_state: HashMap<String, Array2<f32>>,
    meta_step: usize,
}

impl MAMLTrainer {
    pub fn new(config: MAMLConfig, initial_parameters: ModelParameters) -> Self {
        Self {
            config,
            meta_parameters: Arc::new(RwLock::new(initial_parameters)),
            optimizer_state: HashMap::new(),
            meta_step: 0,
        }
    }

    fn compute_inner_gradients(
        &self,
        parameters: &ModelParameters,
        task_data: &TaskData,
    ) -> Result<ModelParameters> {
        let mut gradients = ModelParameters::new();

        for (name, param) in &parameters.parameters {
            let grad = self.compute_parameter_gradient(param, task_data)?;
            gradients.add_parameter(name.clone(), grad);
        }

        Ok(gradients)
    }

    /// Central finite-difference estimate of the task-loss gradient for one parameter.
    fn compute_parameter_gradient(
        &self,
        param: &Array2<f32>,
        data: &TaskData,
    ) -> Result<Array2<f32>> {
        let eps = 1e-5f32;
        let mut gradients = Array2::zeros(param.raw_dim());

        for ((i, j), param_val) in param.indexed_iter() {
            let mut param_plus = param.clone();
            param_plus[[i, j]] = param_val + eps;
            let loss_plus = self.compute_loss_for_parameter(&param_plus, data)?;

            let mut param_minus = param.clone();
            param_minus[[i, j]] = param_val - eps;
            let loss_minus = self.compute_loss_for_parameter(&param_minus, data)?;

            gradients[[i, j]] = (loss_plus - loss_minus) / (2.0 * eps);
        }

        Ok(gradients)
    }

    fn compute_loss_for_parameter(&self, param: &Array2<f32>, data: &TaskData) -> Result<f32> {
        // Heuristic forward model: use the parameter as a weight matrix when
        // its row count matches the input width (so `inputs.dot(param)` is
        // defined), as a broadcast bias row when it matches the target width,
        // and fall back to passing the inputs through unchanged.
        let predictions = if param.nrows() == data.inputs.ncols() {
            data.inputs.dot(param)
        } else if param.shape() == [1, data.targets.ncols()] {
            Array2::from_shape_fn((data.inputs.nrows(), param.ncols()), |(_, j)| param[[0, j]])
        } else {
            data.inputs.clone()
        };

        let diff = &predictions - &data.targets;
        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);

        Ok(mse)
    }

    fn inner_loop_adaptation(
        &self,
        initial_params: &ModelParameters,
        support_set: &TaskData,
    ) -> Result<ModelParameters> {
        let mut adapted_params = initial_params.clone_parameters();

        for _step in 0..self.config.adaptation_steps {
            let gradients = self.compute_inner_gradients(&adapted_params, support_set)?;
            // Learned per-parameter inner learning rates (`learn_inner_lrs`)
            // are not implemented yet; the configured rate is always used.
            let lr = self.config.inner_lr;

            adapted_params.update_with_gradients(&gradients, lr)?;
        }

        Ok(adapted_params)
    }
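
    // MAML in equations, for reference (alpha = `inner_lr`, beta = `meta_lr`,
    // N tasks per batch):
    //   inner loop:  theta_i' = theta - alpha * grad_theta L_support_i(theta)   (repeated)
    //   outer loop:  theta   <- theta - beta * (1/N) * sum_i grad_theta L_query_i(theta_i')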

    fn compute_meta_gradients(&self, task_batch: &TaskBatch) -> Result<(ModelParameters, f32)> {
        let meta_params = self
            .meta_parameters
            .read()
            .expect("lock should not be poisoned");
        let mut meta_gradients = ModelParameters::new();
        let mut total_meta_loss = 0.0;

        for (name, param) in &meta_params.parameters {
            meta_gradients.add_parameter(name.clone(), Array2::zeros(param.raw_dim()));
        }

        for (support_set, query_set) in task_batch.support_sets.iter().zip(&task_batch.query_sets)
        {
            let adapted_params = self.inner_loop_adaptation(&meta_params, support_set)?;

            let query_loss = self.compute_query_loss(&adapted_params, query_set)?;
            total_meta_loss += query_loss;

            let task_meta_grads = if self.config.first_order {
                // First-order surrogate: theta - theta' equals the accumulated
                // inner-loop gradients scaled by the inner learning rate, so
                // using it as the meta-gradient moves the meta-parameters
                // toward the adapted solution (a Reptile-style approximation).
                meta_params.subtract(&adapted_params)?
            } else {
                self.compute_second_order_gradients(&meta_params, &adapted_params, query_set)?
            };

            for (name, grad) in &task_meta_grads.parameters {
                if let Some(meta_grad) = meta_gradients.get_parameter_mut(name) {
                    *meta_grad = meta_grad.clone() + grad;
                }
            }
        }

        // Average the gradients and loss over the task batch.
        let num_tasks = task_batch.num_tasks() as f32;
        for grad in meta_gradients.parameters.values_mut() {
            *grad = grad.clone() / num_tasks;
        }

        total_meta_loss /= num_tasks;

        Ok((meta_gradients, total_meta_loss))
    }

    fn compute_query_loss(&self, params: &ModelParameters, query_set: &TaskData) -> Result<f32> {
        let predictions = self.forward_pass(params, &query_set.inputs)?;

        // Heuristic: single-column targets are treated as regression (MSE),
        // multi-column targets as classification (softmax cross-entropy).
        let loss = if query_set.targets.ncols() == 1 {
            self.compute_mse_loss(&predictions, &query_set.targets)?
        } else {
            self.compute_cross_entropy_loss(&predictions, &query_set.targets)?
        };

        Ok(loss)
    }

    /// Forward pass through a small MLP described by the parameter names
    /// `layer1_weight`/`layer1_bias` (ReLU) and `output_weight`/`output_bias`.
    fn forward_pass(&self, params: &ModelParameters, inputs: &Array2<f32>) -> Result<Array2<f32>> {
        let mut activations = inputs.clone();

        if let Some(layer1_weights) = params.get_parameter("layer1_weight") {
            activations = activations.dot(layer1_weights);

            if let Some(layer1_bias) = params.get_parameter("layer1_bias") {
                // Broadcast-add the bias row to every sample.
                for mut row in activations.rows_mut() {
                    for (i, &bias) in layer1_bias.row(0).iter().enumerate() {
                        if i < row.len() {
                            row[i] += bias;
                        }
                    }
                }
            }

            // ReLU activation.
            activations.mapv_inplace(|x| x.max(0.0));
        }

        if let Some(output_weights) = params.get_parameter("output_weight") {
            activations = activations.dot(output_weights);

            if let Some(output_bias) = params.get_parameter("output_bias") {
                for mut row in activations.rows_mut() {
                    for (i, &bias) in output_bias.row(0).iter().enumerate() {
                        if i < row.len() {
                            row[i] += bias;
                        }
                    }
                }
            }
        }

        Ok(activations)
    }

    fn compute_mse_loss(&self, predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
        let diff = predictions - targets;
        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);
        Ok(mse)
    }

    fn compute_cross_entropy_loss(
        &self,
        predictions: &Array2<f32>,
        targets: &Array2<f32>,
    ) -> Result<f32> {
        let batch_size = predictions.nrows();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let pred_row = predictions.row(i);
            let target_row = targets.row(i);

            // Numerically stable softmax: subtract the row maximum before
            // exponentiating.
            let max_pred = pred_row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let exp_preds: Vec<f32> = pred_row.iter().map(|&x| (x - max_pred).exp()).collect();
            let sum_exp: f32 = exp_preds.iter().sum();
            let softmax_preds: Vec<f32> = exp_preds.iter().map(|&x| x / sum_exp).collect();

            // Cross-entropy against (possibly soft) target distributions.
            let mut row_loss = 0.0;
            for (&pred, &target) in softmax_preds.iter().zip(target_row.iter()) {
                if target > 0.0 {
                    row_loss -= target * pred.max(1e-15).ln();
                }
            }
            total_loss += row_loss;
        }

        Ok(total_loss / batch_size as f32)
    }

    /// Finite-difference estimate of the second-order MAML meta-gradient.
    /// This nests finite differences, so its cost grows quadratically with the
    /// number of parameter entries; it is only practical for very small models.
    fn compute_second_order_gradients(
        &self,
        meta_params: &ModelParameters,
        adapted_params: &ModelParameters,
        query_set: &TaskData,
    ) -> Result<ModelParameters> {
        let eps = 1e-4f32;
        let mut second_order_grads = ModelParameters::new();

        for (param_name, meta_param) in &meta_params.parameters {
            if adapted_params.get_parameter(param_name).is_none() {
                continue;
            }

            let mut param_grad = Array2::zeros(meta_param.raw_dim());

            for ((i, j), _) in meta_param.indexed_iter() {
                // Perturb one meta-parameter entry in both directions and
                // difference the resulting meta-gradients.
                let mut meta_plus = meta_params.clone_parameters();
                let mut meta_minus = meta_params.clone_parameters();

                if let (Some(param_plus), Some(param_minus)) = (
                    meta_plus.get_parameter_mut(param_name),
                    meta_minus.get_parameter_mut(param_name),
                ) {
                    param_plus[[i, j]] += eps;
                    param_minus[[i, j]] -= eps;

                    let grad_plus =
                        self.compute_meta_gradient_at_point(&meta_plus, adapted_params, query_set)?;
                    let grad_minus = self.compute_meta_gradient_at_point(
                        &meta_minus,
                        adapted_params,
                        query_set,
                    )?;

                    if let (Some(g_plus), Some(g_minus)) = (
                        grad_plus.get_parameter(param_name),
                        grad_minus.get_parameter(param_name),
                    ) {
                        param_grad[[i, j]] = (g_plus[[i, j]] - g_minus[[i, j]]) / (2.0 * eps);
                    }
                }
            }

            second_order_grads.add_parameter(param_name.clone(), param_grad);
        }

        Ok(second_order_grads)
    }

    fn compute_meta_gradient_at_point(
        &self,
        meta_params: &ModelParameters,
        adapted_params: &ModelParameters,
        query_set: &TaskData,
    ) -> Result<ModelParameters> {
        let query_loss_grad = self.compute_query_loss_gradients(adapted_params, query_set)?;

        let jacobian = self.compute_adaptation_jacobian(meta_params, adapted_params)?;

        // Chain rule, approximated element-wise:
        // dL/d(meta) ~= dL/d(adapted) * d(adapted)/d(meta).
        let mut meta_grad = ModelParameters::new();
        for (param_name, loss_grad) in &query_loss_grad.parameters {
            if let Some(jac) = jacobian.get_parameter(param_name) {
                let meta_gradient = loss_grad * jac;
                meta_grad.add_parameter(param_name.clone(), meta_gradient);
            }
        }

        Ok(meta_grad)
    }

    fn compute_query_loss_gradients(
        &self,
        params: &ModelParameters,
        query_set: &TaskData,
    ) -> Result<ModelParameters> {
        let mut gradients = ModelParameters::new();

        for param_name in params.parameters.keys() {
            let grad = self.compute_parameter_gradient_for_query(params, param_name, query_set)?;
            gradients.add_parameter(param_name.clone(), grad);
        }

        Ok(gradients)
    }

    fn compute_parameter_gradient_for_query(
        &self,
        params: &ModelParameters,
        param_name: &str,
        query_set: &TaskData,
    ) -> Result<Array2<f32>> {
        let eps = 1e-5f32;
        let param = params
            .get_parameter(param_name)
            .ok_or_else(|| anyhow::anyhow!("unknown parameter: {}", param_name))?;
        let mut gradients = Array2::zeros(param.raw_dim());

        for ((i, j), param_val) in param.indexed_iter() {
            // Perturb the named parameter inside a full copy of the parameter
            // set, so the forward pass sees it under its real name (perturbing
            // it in isolation under a placeholder name would leave the query
            // loss unchanged and yield zero gradients).
            let mut params_plus = params.clone_parameters();
            let mut params_minus = params.clone_parameters();
            if let Some(p) = params_plus.get_parameter_mut(param_name) {
                p[[i, j]] = param_val + eps;
            }
            if let Some(p) = params_minus.get_parameter_mut(param_name) {
                p[[i, j]] = param_val - eps;
            }

            let loss_plus = self.compute_query_loss(&params_plus, query_set)?;
            let loss_minus = self.compute_query_loss(&params_minus, query_set)?;

            gradients[[i, j]] = (loss_plus - loss_minus) / (2.0 * eps);
        }

        Ok(gradients)
    }

    fn compute_adaptation_jacobian(
        &self,
        meta_params: &ModelParameters,
        adapted_params: &ModelParameters,
    ) -> Result<ModelParameters> {
        let mut jacobian = ModelParameters::new();

        // First-order approximation: treat d(adapted)/d(meta) as the identity.
        // It is consumed by an element-wise product in
        // `compute_meta_gradient_at_point`, so the identity is represented as
        // an array of ones with the parameter's own shape.
        for (param_name, meta_param) in &meta_params.parameters {
            if adapted_params.get_parameter(param_name).is_some() {
                let identity_jac = Array2::ones(meta_param.raw_dim());
                jacobian.add_parameter(param_name.clone(), identity_jac);
            }
        }

        Ok(jacobian)
    }

    /// Clips gradients to `config.grad_clip` by global norm and returns the
    /// pre-clip norm.
    fn clip_gradients(&self, gradients: &mut ModelParameters) -> f32 {
        let mut total_norm = 0.0;

        for grad in gradients.parameters.values() {
            total_norm += grad.mapv(|x| x * x).sum();
        }
        total_norm = total_norm.sqrt();

        if total_norm > self.config.grad_clip {
            let clip_coef = self.config.grad_clip / total_norm;
            for grad in gradients.parameters.values_mut() {
                *grad = grad.clone() * clip_coef;
            }
        }

        total_norm
    }
}

impl MetaLearningAlgorithm for MAMLTrainer {
    fn meta_update(&mut self, task_batch: &TaskBatch) -> Result<MetaUpdateResult> {
        let (mut meta_gradients, meta_loss) = self.compute_meta_gradients(task_batch)?;
        let grad_norm = self.clip_gradients(&mut meta_gradients);

        {
            let mut meta_params = self
                .meta_parameters
                .write()
                .expect("lock should not be poisoned");
            meta_params.update_with_gradients(&meta_gradients, self.config.meta_lr)?;
        }

        self.meta_step += 1;

        Ok(MetaUpdateResult {
            meta_loss,
            // Per-task losses are not tracked separately here; the batch mean
            // is repeated as a placeholder for each task.
            task_losses: vec![meta_loss; task_batch.num_tasks()],
            grad_norm,
            updated_parameters: self
                .meta_parameters
                .read()
                .expect("lock should not be poisoned")
                .clone_parameters(),
        })
    }

    fn adapt(&self, support_set: &TaskData, adaptation_steps: usize) -> Result<ModelParameters> {
        let meta_params = self
            .meta_parameters
            .read()
            .expect("lock should not be poisoned");
        let mut adapted_params = meta_params.clone_parameters();

        for _ in 0..adaptation_steps {
            let gradients = self.compute_inner_gradients(&adapted_params, support_set)?;
            adapted_params.update_with_gradients(&gradients, self.config.inner_lr)?;
        }

        Ok(adapted_params)
    }

    fn evaluate(&self, params: &ModelParameters, query_set: &TaskData) -> Result<f32> {
        self.compute_query_loss(params, query_set)
    }
}

/// Trainer implementing Reptile with finite-difference gradients.
pub struct ReptileTrainer {
    config: ReptileConfig,
    meta_parameters: Arc<RwLock<ModelParameters>>,
    meta_step: usize,
}

impl ReptileTrainer {
    pub fn new(config: ReptileConfig, initial_parameters: ModelParameters) -> Self {
        Self {
            config,
            meta_parameters: Arc::new(RwLock::new(initial_parameters)),
            meta_step: 0,
        }
    }

    /// Runs plain SGD on a single task, starting from the current meta-parameters.
    fn sgd_on_task(&self, task_data: &TaskData) -> Result<ModelParameters> {
        let meta_params = self
            .meta_parameters
            .read()
            .expect("lock should not be poisoned");
        let mut task_params = meta_params.clone_parameters();

        for _ in 0..self.config.adaptation_steps {
            let gradients = self.compute_task_gradients(&task_params, task_data)?;
            task_params.update_with_gradients(&gradients, self.config.inner_lr)?;
        }

        Ok(task_params)
    }

    fn compute_task_gradients(
        &self,
        params: &ModelParameters,
        data: &TaskData,
    ) -> Result<ModelParameters> {
        let mut gradients = ModelParameters::new();

        for (param_name, param) in &params.parameters {
            let grad = self.compute_task_parameter_gradient(param, data, param_name)?;
            gradients.add_parameter(param_name.clone(), grad);
        }

        Ok(gradients)
    }

    /// Central finite-difference estimate of the task-loss gradient for one parameter.
    fn compute_task_parameter_gradient(
        &self,
        param: &Array2<f32>,
        data: &TaskData,
        param_name: &str,
    ) -> Result<Array2<f32>> {
        let eps = 1e-5f32;
        let mut gradients = Array2::zeros(param.raw_dim());

        for ((i, j), param_val) in param.indexed_iter() {
            let mut param_plus = param.clone();
            let mut param_minus = param.clone();
            param_plus[[i, j]] = param_val + eps;
            param_minus[[i, j]] = param_val - eps;

            let loss_plus = self.compute_task_loss_for_param(&param_plus, data, param_name)?;
            let loss_minus = self.compute_task_loss_for_param(&param_minus, data, param_name)?;

            gradients[[i, j]] = (loss_plus - loss_minus) / (2.0 * eps);
        }

        Ok(gradients)
    }

    fn compute_task_loss_for_param(
        &self,
        param: &Array2<f32>,
        data: &TaskData,
        param_name: &str,
    ) -> Result<f32> {
        // Heuristic forward model keyed on the parameter name: weights are
        // applied as a matrix product (transposed when that is the orientation
        // that lines up with the inputs), biases are broadcast-added, and
        // anything else passes the inputs through unchanged.
        let predictions = if param_name.contains("weight") {
            if param.nrows() == data.inputs.ncols() {
                data.inputs.dot(param)
            } else if param.ncols() == data.inputs.ncols() {
                data.inputs.dot(&param.t())
            } else {
                data.inputs.clone()
            }
        } else if param_name.contains("bias") {
            let mut result = data.inputs.clone();
            for mut row in result.rows_mut() {
                for (k, &bias) in param.row(0).iter().enumerate() {
                    if k < row.len() {
                        row[k] += bias;
                    }
                }
            }
            result
        } else {
            data.inputs.clone()
        };

        let diff = &predictions - &data.targets;
        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);

        Ok(mse)
    }
}

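// Reptile in equations, for reference (beta = `meta_lr`, N tasks per batch):
// k SGD steps on task i yield adapted parameters phi_i, and the
// meta-parameters move toward the average adapted solution:
//   theta <- theta + beta * (1/N) * sum_i (phi_i - theta)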
impl MetaLearningAlgorithm for ReptileTrainer {
    fn meta_update(&mut self, task_batch: &TaskBatch) -> Result<MetaUpdateResult> {
        let mut total_update = ModelParameters::new();
        let mut total_loss = 0.0;

        {
            let meta_params = self
                .meta_parameters
                .read()
                .expect("lock should not be poisoned");
            for (name, param) in &meta_params.parameters {
                total_update.add_parameter(name.clone(), Array2::zeros(param.raw_dim()));
            }
        }

        for (support_set, query_set) in task_batch.support_sets.iter().zip(&task_batch.query_sets)
        {
            let task_params = self.sgd_on_task(support_set)?;

            // Accumulate the direction from the meta-parameters toward the
            // task-adapted parameters.
            let meta_params = self
                .meta_parameters
                .read()
                .expect("lock should not be poisoned");
            let update = task_params.subtract(&meta_params)?;

            for (name, param_update) in &update.parameters {
                if let Some(total_param_update) = total_update.get_parameter_mut(name) {
                    *total_param_update = total_param_update.clone() + param_update;
                }
            }

            let loss = self.evaluate(&task_params, query_set)?;
            total_loss += loss;
        }

        // Average the update direction and loss over the task batch.
        let num_tasks = task_batch.num_tasks() as f32;
        for param_update in total_update.parameters.values_mut() {
            *param_update = param_update.clone() / num_tasks;
        }
        total_loss /= num_tasks;

        {
            let mut meta_params = self
                .meta_parameters
                .write()
                .expect("lock should not be poisoned");
            let scaled_update = total_update.scale(self.config.meta_lr);
            *meta_params = meta_params.add(&scaled_update)?;
        }

        self.meta_step += 1;

        Ok(MetaUpdateResult {
            meta_loss: total_loss,
            // Per-task losses are not tracked separately; the batch mean is
            // repeated as a placeholder, and Reptile computes no gradient norm.
            task_losses: vec![total_loss; task_batch.num_tasks()],
            grad_norm: 0.0,
            updated_parameters: self
                .meta_parameters
                .read()
                .expect("lock should not be poisoned")
                .clone_parameters(),
        })
    }

    fn adapt(&self, support_set: &TaskData, adaptation_steps: usize) -> Result<ModelParameters> {
        let meta_params = self
            .meta_parameters
            .read()
            .expect("lock should not be poisoned");
        let mut adapted_params = meta_params.clone_parameters();

        for _ in 0..adaptation_steps {
            let gradients = self.compute_task_gradients(&adapted_params, support_set)?;
            adapted_params.update_with_gradients(&gradients, self.config.inner_lr)?;
        }

        Ok(adapted_params)
    }

    fn evaluate(&self, params: &ModelParameters, query_set: &TaskData) -> Result<f32> {
        let mut predictions = query_set.inputs.clone();

        // Note: `HashMap` iteration order is unspecified, so the stored weights
        // and biases are applied in arbitrary order; this is only a rough
        // heuristic forward pass.
        for (param_name, param) in &params.parameters {
            if param_name.contains("weight") {
                if param.nrows() == predictions.ncols() {
                    predictions = predictions.dot(param);
                } else if param.ncols() == predictions.ncols() {
                    predictions = predictions.dot(&param.t());
                }

                // ReLU after hidden layers, identity at the output.
                if !param_name.contains("output") {
                    predictions.mapv_inplace(|x| x.max(0.0));
                }
            } else if param_name.contains("bias") {
                for mut row in predictions.rows_mut() {
                    for (k, &bias) in param.row(0).iter().enumerate() {
                        if k < row.len() {
                            row[k] += bias;
                        }
                    }
                }
            }
        }

        let diff = &predictions - &query_set.targets;
        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);

        Ok(mse)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_parameters() {
        let mut params = ModelParameters::new();
        let tensor = Array2::ones((2, 3));
        params.add_parameter("layer1".to_string(), tensor.clone());

        assert_eq!(
            params.get_parameter("layer1").expect("parameter should exist"),
            &tensor
        );
        assert_eq!(
            params.shapes.get("layer1").expect("shape should be recorded"),
            &vec![2, 3]
        );
    }
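
    // A small sketch exercising the SGD update rule in
    // `ModelParameters::update_with_gradients` (param <- param - lr * grad).
    #[test]
    fn test_update_with_gradients() {
        let mut params = ModelParameters::new();
        params.add_parameter("w".to_string(), Array2::<f32>::ones((2, 2)));

        let mut grads = ModelParameters::new();
        grads.add_parameter("w".to_string(), Array2::<f32>::ones((2, 2)));

        params
            .update_with_gradients(&grads, 0.5)
            .expect("update should succeed");

        // 1.0 - 0.5 * 1.0 = 0.5 in every entry.
        assert_eq!(
            params.get_parameter("w").expect("parameter should exist"),
            &(Array2::<f32>::ones((2, 2)) * 0.5)
        );
    }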

    #[test]
    fn test_task_data() {
        let inputs = Array2::ones((10, 5));
        let targets = Array2::zeros((10, 2));
        let task_data = TaskData::new(inputs, targets, "test_task".to_string());

        assert_eq!(task_data.batch_size(), 10);
        assert_eq!(task_data.task_id, "test_task");
    }
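
    // A sketch of `TaskData::split_batches`: 10 samples with batch size 4
    // yield batches of 4, 4, and 2 rows with derived task ids.
    #[test]
    fn test_split_batches() {
        let inputs = Array2::<f32>::ones((10, 4));
        let targets = Array2::<f32>::zeros((10, 1));
        let task = TaskData::new(inputs, targets, "t".to_string());

        let batches = task.split_batches(4);

        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].batch_size(), 4);
        assert_eq!(batches[2].batch_size(), 2);
        assert_eq!(batches[2].task_id, "t_batch_2");
    }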

    #[test]
    fn test_task_batch() {
        let support = vec![
            TaskData::new(
                Array2::ones((5, 3)),
                Array2::zeros((5, 1)),
                "task1".to_string(),
            ),
            TaskData::new(
                Array2::ones((5, 3)),
                Array2::zeros((5, 1)),
                "task2".to_string(),
            ),
        ];
        let query = vec![
            TaskData::new(
                Array2::ones((3, 3)),
                Array2::zeros((3, 1)),
                "task1".to_string(),
            ),
            TaskData::new(
                Array2::ones((3, 3)),
                Array2::zeros((3, 1)),
                "task2".to_string(),
            ),
        ];

        let batch = TaskBatch::new(support, query).expect("batch construction should succeed");
        assert_eq!(batch.num_tasks(), 2);
    }

    #[test]
    fn test_maml_trainer_creation() {
        let config = MAMLConfig::default();
        let mut params = ModelParameters::new();
        params.add_parameter("test".to_string(), Array2::<f32>::ones((2, 2)));

        let trainer = MAMLTrainer::new(config, params);
        assert_eq!(trainer.meta_step, 0);
    }
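
    // Sketches of the MAML trainer's internal helpers, exercised directly
    // (this tests module can access private methods of the parent module).
    #[test]
    fn test_forward_pass_single_layer() {
        let mut params = ModelParameters::new();
        params.add_parameter("layer1_weight".to_string(), Array2::<f32>::ones((2, 2)));
        let trainer = MAMLTrainer::new(MAMLConfig::default(), params.clone_parameters());

        let inputs = Array2::<f32>::ones((1, 2));
        let out = trainer
            .forward_pass(&params, &inputs)
            .expect("forward pass should succeed");

        // [1, 1] . ones(2x2) = [2, 2]; ReLU leaves positive values unchanged.
        assert_eq!(out, Array2::<f32>::ones((1, 2)) * 2.0);
    }

    #[test]
    fn test_cross_entropy_uniform_softmax() {
        let trainer = MAMLTrainer::new(MAMLConfig::default(), ModelParameters::new());
        let predictions = Array2::<f32>::zeros((1, 2));
        let targets = Array2::from_shape_vec((1, 2), vec![1.0f32, 0.0]).expect("shape is valid");

        let loss = trainer
            .compute_cross_entropy_loss(&predictions, &targets)
            .expect("loss should succeed");

        // Equal logits give a uniform softmax, so the loss is -ln(0.5) = ln 2.
        assert!((loss - std::f32::consts::LN_2).abs() < 1e-5);
    }

    #[test]
    fn test_clip_gradients_rescales_by_global_norm() {
        let config = MAMLConfig {
            grad_clip: 1.0,
            ..Default::default()
        };
        let trainer = MAMLTrainer::new(config, ModelParameters::new());

        let mut grads = ModelParameters::new();
        grads.add_parameter("w".to_string(), Array2::<f32>::ones((2, 2)) * 3.0);

        let norm = trainer.clip_gradients(&mut grads);

        // Pre-clip norm is sqrt(4 * 3^2) = 6; entries are rescaled to 3/6 = 0.5.
        assert!((norm - 6.0).abs() < 1e-4);
        let clipped = grads.get_parameter("w").expect("parameter should exist");
        assert!((clipped[[0, 0]] - 0.5).abs() < 1e-5);
    }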

    #[test]
    fn test_reptile_trainer_creation() {
        let config = ReptileConfig::default();
        let mut params = ModelParameters::new();
        params.add_parameter("test".to_string(), Array2::<f32>::ones((2, 2)));

        let trainer = ReptileTrainer::new(config, params);
        assert_eq!(trainer.meta_step, 0);
    }

    #[test]
    fn test_parameter_operations() {
        let mut params1 = ModelParameters::new();
        let mut params2 = ModelParameters::new();

        params1.add_parameter("layer1".to_string(), Array2::<f32>::ones((2, 2)));
        params2.add_parameter("layer1".to_string(), Array2::<f32>::ones((2, 2)) * 2.0);

        let diff = params2.subtract(&params1).expect("subtract should succeed");
        let sum = params1.add(&params2).expect("add should succeed");
        let scaled = params1.scale(2.0);

        assert_eq!(
            diff.get_parameter("layer1").expect("parameter should exist"),
            &Array2::<f32>::ones((2, 2))
        );
        assert_eq!(
            sum.get_parameter("layer1").expect("parameter should exist"),
            &(Array2::<f32>::ones((2, 2)) * 3.0)
        );
        assert_eq!(
            scaled.get_parameter("layer1").expect("parameter should exist"),
            &(Array2::<f32>::ones((2, 2)) * 2.0)
        );
    }
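
    // A sketch of the `TaskBatch::new` length check: mismatched support/query
    // lists are rejected.
    #[test]
    fn test_task_batch_length_mismatch() {
        let support = vec![TaskData::new(
            Array2::ones((5, 3)),
            Array2::zeros((5, 1)),
            "task1".to_string(),
        )];
        let query: Vec<TaskData> = Vec::new();

        assert!(TaskBatch::new(support, query).is_err());
    }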
}