1use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38use trustformers_core::{
39 errors::{invalid_input, unsupported_operation, TrustformersError},
40 tensor::Tensor,
41};
42
/// Configuration shared by all supported meta-learning algorithms.
///
/// Fields that apply only to specific algorithms (e.g. the `memory_*`
/// fields for MANN) are ignored by the others.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningConfig {
    /// Which meta-learning algorithm to run; selects the model and optimizer.
    pub algorithm: MetaAlgorithm,
    /// Learning rate for the inner (per-task adaptation) loop.
    pub inner_lr: f64,
    /// Learning rate for the outer (meta) update.
    pub meta_lr: f64,
    /// Number of adaptation steps in the inner loop.
    pub inner_steps: usize,
    /// Number of examples generated for each task's support set.
    pub support_size: usize,
    /// Number of examples generated for each task's query set.
    pub query_size: usize,
    /// Number of classes per task (the N in N-way classification).
    pub num_ways: usize,
    /// Examples per class (the K in K-shot). NOTE(review): not read in this
    /// file — confirm usage elsewhere.
    pub num_shots: usize,
    /// Use the first-order approximation when computing MAML meta-gradients.
    pub first_order: bool,
    /// Softmax temperature used when converting distances to log-probabilities.
    pub temperature: f64,
    /// Dimensionality of example input/embedding tensors.
    pub embedding_dim: usize,
    /// Whether embeddings should be L2-normalized. NOTE(review): not read in
    /// this file — confirm usage elsewhere.
    pub normalize_embeddings: bool,
    /// MANN: number of external memory slots. NOTE(review): not read here.
    pub memory_size: usize,
    /// MANN: dimensionality of memory keys. NOTE(review): not read here.
    pub memory_key_dim: usize,
    /// MANN: dimensionality of memory values. NOTE(review): not read here.
    pub memory_value_dim: usize,
    /// Number of tasks sampled per meta-training episode.
    pub meta_batch_size: usize,
    /// Keep per-task parameter copies. NOTE(review): not read here.
    pub task_specific_params: bool,
    /// Inner-loop L2 regularization strength. NOTE(review): not read here.
    pub inner_l2_reg: f64,
    /// Gradient-norm clipping threshold. NOTE(review): not read here.
    pub grad_clip_norm: f64,
}
85
impl Default for MetaLearningConfig {
    /// Sensible defaults for 5-way 1-shot episodic training with MAML.
    fn default() -> Self {
        Self {
            algorithm: MetaAlgorithm::MAML,
            inner_lr: 0.01,
            meta_lr: 0.001,
            inner_steps: 5,
            support_size: 5,
            query_size: 15,
            num_ways: 5,
            num_shots: 1,
            first_order: false,
            temperature: 1.0,
            embedding_dim: 512,
            normalize_embeddings: true,
            memory_size: 128,
            memory_key_dim: 64,
            memory_value_dim: 256,
            meta_batch_size: 32,
            task_specific_params: false,
            inner_l2_reg: 0.0001,
            grad_clip_norm: 10.0,
        }
    }
}
111
/// Supported meta-learning algorithm families.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MetaAlgorithm {
    /// Model-Agnostic Meta-Learning: inner-loop adaptation plus
    /// second-order (or optionally first-order) meta-gradients.
    MAML,
    /// Reptile: meta-gradient is the parameter delta after adaptation.
    Reptile,
    /// Prototypical networks: class-mean prototypes in embedding space.
    ProtoNet,
    /// Matching networks: attention over support examples.
    MatchingNet,
    /// Relation networks: learned pairwise relation scores.
    RelationNet,
    /// Memory-augmented neural network with an external memory.
    MANN,
    /// Gradient-based meta-learning with a learned update rule.
    GBML,
    /// Meta-SGD: learned per-parameter learning rates.
    MetaSGD,
    /// Learning-to-learn: an LSTM proposes parameter updates.
    L2L,
}
134
/// Orchestrates episodic meta-training: owns the model, the meta-optimizer,
/// the task sampler, and per-episode bookkeeping.
pub struct MetaLearner {
    /// Algorithm configuration used for model/optimizer selection and loops.
    config: MetaLearningConfig,
    /// Underlying model implementation, chosen by `config.algorithm`.
    model: Box<dyn MetaLearningModel>,
    /// Outer-loop optimizer that applies accumulated meta-gradients.
    optimizer: Box<dyn MetaOptimizer>,
    /// Source of training/evaluation task batches.
    task_sampler: TaskSampler,
    /// Running aggregate statistics across episodes.
    meta_statistics: MetaStatistics,
    /// Result of every completed episode, in order.
    episode_history: Vec<EpisodeResult>,
    /// Index assigned to the next episode to run.
    current_episode: usize,
}
145
146impl MetaLearner {
147 pub fn new(config: MetaLearningConfig) -> Result<Self, TrustformersError> {
149 let model = Self::create_model(&config)?;
150 let optimizer = Self::create_optimizer(&config)?;
151 let task_sampler = TaskSampler::new(&config)?;
152
153 Ok(Self {
154 config,
155 model,
156 optimizer,
157 task_sampler,
158 meta_statistics: MetaStatistics::new(),
159 episode_history: Vec::new(),
160 current_episode: 0,
161 })
162 }
163
164 fn create_model(
166 config: &MetaLearningConfig,
167 ) -> Result<Box<dyn MetaLearningModel>, TrustformersError> {
168 match config.algorithm {
169 MetaAlgorithm::MAML => Ok(Box::new(MAMLModel::new(config)?)),
170 MetaAlgorithm::Reptile => Ok(Box::new(ReptileModel::new(config)?)),
171 MetaAlgorithm::ProtoNet => Ok(Box::new(PrototypicalModel::new(config)?)),
172 MetaAlgorithm::MatchingNet => Ok(Box::new(MatchingNetModel::new(config)?)),
173 MetaAlgorithm::RelationNet => Ok(Box::new(RelationNetModel::new(config)?)),
174 MetaAlgorithm::MANN => Ok(Box::new(MemoryAugmentedModel::new(config)?)),
175 MetaAlgorithm::GBML => Ok(Box::new(GradientBasedModel::new(config)?)),
176 MetaAlgorithm::MetaSGD => Ok(Box::new(MetaSGDModel::new(config)?)),
177 MetaAlgorithm::L2L => Ok(Box::new(L2LModel::new(config)?)),
178 }
179 }
180
181 fn create_optimizer(
183 config: &MetaLearningConfig,
184 ) -> Result<Box<dyn MetaOptimizer>, TrustformersError> {
185 match config.algorithm {
186 MetaAlgorithm::MAML | MetaAlgorithm::Reptile | MetaAlgorithm::GBML => {
187 Ok(Box::new(SGDMetaOptimizer::new(config.meta_lr)?))
188 },
189 MetaAlgorithm::MetaSGD => Ok(Box::new(LearnedLROptimizer::new(config.meta_lr)?)),
190 _ => Ok(Box::new(AdamMetaOptimizer::new(config.meta_lr)?)),
191 }
192 }
193
194 pub fn train_episode(
196 &mut self,
197 task_batch: TaskBatch,
198 ) -> Result<EpisodeResult, TrustformersError> {
199 let start_time = std::time::Instant::now();
200 let mut total_loss = 0.0;
201 let mut total_accuracy = 0.0;
202 let num_tasks = task_batch.tasks.len();
203
204 for task in &task_batch.tasks {
205 let task_result = self.train_single_task(task)?;
206 total_loss += task_result.query_loss;
207 total_accuracy += task_result.query_accuracy;
208 }
209
210 self.optimizer.step(&mut *self.model)?;
212
213 let episode_result = EpisodeResult {
214 episode: self.current_episode,
215 meta_loss: total_loss / num_tasks as f64,
216 meta_accuracy: total_accuracy / num_tasks as f64,
217 num_tasks,
218 episode_time: start_time.elapsed(),
219 algorithm: self.config.algorithm,
220 };
221
222 self.episode_history.push(episode_result.clone());
223 self.meta_statistics.update(&episode_result);
224 self.current_episode += 1;
225
226 Ok(episode_result)
227 }
228
229 fn train_single_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
231 match self.config.algorithm {
232 MetaAlgorithm::MAML => self.train_maml_task(task),
233 MetaAlgorithm::Reptile => self.train_reptile_task(task),
234 MetaAlgorithm::ProtoNet => self.train_prototypical_task(task),
235 MetaAlgorithm::MatchingNet => self.train_matching_task(task),
236 MetaAlgorithm::RelationNet => self.train_relation_task(task),
237 MetaAlgorithm::MANN => self.train_memory_task(task),
238 MetaAlgorithm::GBML => self.train_gradient_based_task(task),
239 MetaAlgorithm::MetaSGD => self.train_meta_sgd_task(task),
240 MetaAlgorithm::L2L => self.train_l2l_task(task),
241 }
242 }
243
244 fn train_maml_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
246 let initial_params = self.model.get_parameters()?;
248
249 for _ in 0..self.config.inner_steps {
251 let support_loss = self.model.forward(&task.support_set)?;
252 let gradients = self.model.compute_gradients(support_loss)?;
253
254 self.model.apply_gradients(&gradients, self.config.inner_lr)?;
256 }
257
258 let query_loss = self.model.forward(&task.query_set)?;
260 let query_accuracy = self.model.compute_accuracy(&task.query_set)?;
261
262 let meta_gradients = if self.config.first_order {
264 self.model.compute_first_order_gradients(query_loss)?
266 } else {
267 self.model.compute_second_order_gradients(&initial_params, query_loss)?
269 };
270
271 self.optimizer.accumulate_gradients(meta_gradients)?;
273
274 self.model.set_parameters(initial_params)?;
276
277 Ok(TaskResult {
278 support_loss: 0.0, query_loss,
280 query_accuracy,
281 adaptation_time: std::time::Duration::from_millis(0),
282 })
283 }
284
285 fn train_reptile_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
287 let initial_params = self.model.get_parameters()?;
288
289 for _ in 0..self.config.inner_steps {
291 let support_loss = self.model.forward(&task.support_set)?;
292 let gradients = self.model.compute_gradients(support_loss)?;
293 self.model.apply_gradients(&gradients, self.config.inner_lr)?;
294 }
295
296 let adapted_params = self.model.get_parameters()?;
297 let query_loss = self.model.forward(&task.query_set)?;
298 let query_accuracy = self.model.compute_accuracy(&task.query_set)?;
299
300 let meta_gradients = self.compute_param_difference(&initial_params, &adapted_params)?;
302 self.optimizer.accumulate_gradients(meta_gradients)?;
303
304 self.model.set_parameters(initial_params)?;
306
307 Ok(TaskResult {
308 support_loss: 0.0,
309 query_loss,
310 query_accuracy,
311 adaptation_time: std::time::Duration::from_millis(0),
312 })
313 }
314
315 fn train_prototypical_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
317 let prototypes = self.compute_prototypes(&task.support_set)?;
319
320 let query_loss = self.compute_prototypical_loss(&task.query_set, &prototypes)?;
322 let query_accuracy = self.compute_prototypical_accuracy(&task.query_set, &prototypes)?;
323
324 let gradients = self.model.compute_gradients(query_loss)?;
326 self.optimizer.accumulate_gradients(gradients)?;
327
328 Ok(TaskResult {
329 support_loss: 0.0,
330 query_loss,
331 query_accuracy,
332 adaptation_time: std::time::Duration::from_millis(0),
333 })
334 }
335
336 fn train_matching_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
338 let attention_weights =
340 self.compute_attention_weights(&task.query_set, &task.support_set)?;
341
342 let predictions =
344 self.compute_matching_predictions(&attention_weights, &task.support_set)?;
345
346 let query_loss = self.compute_matching_loss(&predictions, &task.query_set)?;
347 let query_accuracy = self.compute_matching_accuracy(&predictions, &task.query_set)?;
348
349 let gradients = self.model.compute_gradients(query_loss)?;
350 self.optimizer.accumulate_gradients(gradients)?;
351
352 Ok(TaskResult {
353 support_loss: 0.0,
354 query_loss,
355 query_accuracy,
356 adaptation_time: std::time::Duration::from_millis(0),
357 })
358 }
359
360 fn train_relation_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
362 let mut total_loss = 0.0;
363 let mut correct_predictions = 0;
364 let mut total_predictions = 0;
365
366 for query_example in &task.query_set.examples {
368 let query_embedding = self.model.embed(query_example)?;
369 let mut relation_scores = Vec::new();
370
371 for support_example in &task.support_set.examples {
372 let support_embedding = self.model.embed(support_example)?;
373 let relation_score =
374 self.model.compute_relation(&query_embedding, &support_embedding)?;
375 relation_scores.push(relation_score);
376 }
377
378 let loss =
380 self.compute_relation_loss(&relation_scores, query_example, &task.support_set)?;
381 total_loss += loss;
382
383 if self.is_correct_prediction(&relation_scores, query_example, &task.support_set)? {
384 correct_predictions += 1;
385 }
386 total_predictions += 1;
387 }
388
389 let query_loss = total_loss / total_predictions as f64;
390 let query_accuracy = correct_predictions as f64 / total_predictions as f64;
391
392 let gradients = self.model.compute_gradients(query_loss)?;
393 self.optimizer.accumulate_gradients(gradients)?;
394
395 Ok(TaskResult {
396 support_loss: 0.0,
397 query_loss,
398 query_accuracy,
399 adaptation_time: std::time::Duration::from_millis(0),
400 })
401 }
402
403 fn train_memory_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
405 for example in &task.support_set.examples {
407 self.model.write_to_memory(example)?;
408 }
409
410 let mut total_loss = 0.0;
411 let mut correct_predictions = 0;
412 let total_predictions = task.query_set.examples.len();
413
414 for query_example in &task.query_set.examples {
416 let memory_output = self.model.read_from_memory(query_example)?;
417 let prediction = self.model.predict_from_memory(&memory_output)?;
418
419 let loss = self.compute_memory_loss(&prediction, query_example)?;
420 total_loss += loss;
421
422 if self.is_memory_prediction_correct(&prediction, query_example)? {
423 correct_predictions += 1;
424 }
425 }
426
427 let query_loss = total_loss / total_predictions as f64;
428 let query_accuracy = correct_predictions as f64 / total_predictions as f64;
429
430 let gradients = self.model.compute_gradients(query_loss)?;
431 self.optimizer.accumulate_gradients(gradients)?;
432
433 self.model.clear_memory()?;
435
436 Ok(TaskResult {
437 support_loss: 0.0,
438 query_loss,
439 query_accuracy,
440 adaptation_time: std::time::Duration::from_millis(0),
441 })
442 }
443
444 fn train_gradient_based_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
446 let meta_learner_state = self.model.get_meta_learner_state()?;
448
449 let adapted_params =
451 self.model.apply_learned_algorithm(&task.support_set, &meta_learner_state)?;
452
453 let query_loss = self.model.evaluate_with_params(&task.query_set, &adapted_params)?;
455 let query_accuracy =
456 self.model.compute_accuracy_with_params(&task.query_set, &adapted_params)?;
457
458 let gradients = self.model.compute_meta_learner_gradients(query_loss)?;
460 self.optimizer.accumulate_gradients(gradients)?;
461
462 Ok(TaskResult {
463 support_loss: 0.0,
464 query_loss,
465 query_accuracy,
466 adaptation_time: std::time::Duration::from_millis(0),
467 })
468 }
469
470 fn train_meta_sgd_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
472 let initial_params = self.model.get_parameters()?;
473 let learning_rates = self.model.get_learning_rates()?;
474
475 for _ in 0..self.config.inner_steps {
477 let support_loss = self.model.forward(&task.support_set)?;
478 let gradients = self.model.compute_gradients(support_loss)?;
479
480 self.model.apply_gradients_with_lr(&gradients, &learning_rates)?;
482 }
483
484 let query_loss = self.model.forward(&task.query_set)?;
485 let query_accuracy = self.model.compute_accuracy(&task.query_set)?;
486
487 let param_gradients =
489 self.model.compute_second_order_gradients(&initial_params, query_loss)?;
490 let lr_gradients = self.model.compute_lr_gradients(query_loss)?;
491
492 self.optimizer.accumulate_param_gradients(param_gradients)?;
493 self.optimizer.accumulate_lr_gradients(lr_gradients)?;
494
495 self.model.set_parameters(initial_params)?;
497
498 Ok(TaskResult {
499 support_loss: 0.0,
500 query_loss,
501 query_accuracy,
502 adaptation_time: std::time::Duration::from_millis(0),
503 })
504 }
505
506 fn train_l2l_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
508 let mut lstm_state = self.model.get_lstm_state()?;
510 let initial_params = self.model.get_parameters()?;
511
512 for step in 0..self.config.inner_steps {
514 let support_loss = self.model.forward(&task.support_set)?;
515 let gradients = self.model.compute_gradients(support_loss)?;
516
517 let (updates, new_lstm_state) =
519 self.model.lstm_update(&gradients, &lstm_state, step)?;
520 lstm_state = new_lstm_state;
521
522 self.model.apply_lstm_updates(&updates)?;
524 }
525
526 let query_loss = self.model.forward(&task.query_set)?;
527 let query_accuracy = self.model.compute_accuracy(&task.query_set)?;
528
529 let lstm_gradients = self.model.compute_lstm_gradients(query_loss)?;
531 self.optimizer.accumulate_gradients(lstm_gradients)?;
532
533 self.model.set_parameters(initial_params)?;
535
536 Ok(TaskResult {
537 support_loss: 0.0,
538 query_loss,
539 query_accuracy,
540 adaptation_time: std::time::Duration::from_millis(0),
541 })
542 }
543
544 fn compute_param_difference(
546 &self,
547 params1: &ModelParameters,
548 params2: &ModelParameters,
549 ) -> Result<ModelGradients, TrustformersError> {
550 let mut gradients = ModelGradients::new();
552
553 for (name, param1) in ¶ms1.parameters {
554 if let Some(param2) = params2.parameters.get(name) {
555 let diff = param2.sub(param1)?; gradients.gradients.insert(name.clone(), diff);
557 }
558 }
559
560 Ok(gradients)
561 }
562
563 fn compute_prototypes(
564 &self,
565 support_set: &ExampleSet,
566 ) -> Result<Vec<Tensor>, TrustformersError> {
567 let mut prototypes = Vec::new();
568 let num_classes = self.config.num_ways;
569
570 for class_id in 0..num_classes {
571 let mut class_embeddings = Vec::new();
572
573 for example in &support_set.examples {
575 if example.label == class_id {
576 let embedding = self.model.embed(example)?;
577 class_embeddings.push(embedding);
578 }
579 }
580
581 if !class_embeddings.is_empty() {
583 let prototype = self.compute_mean_embedding(&class_embeddings)?;
584 prototypes.push(prototype);
585 }
586 }
587
588 Ok(prototypes)
589 }
590
591 fn compute_mean_embedding(&self, embeddings: &[Tensor]) -> Result<Tensor, TrustformersError> {
592 if embeddings.is_empty() {
593 return Err(invalid_input("Empty embeddings list"));
594 }
595
596 let mut sum = embeddings[0].clone();
597 for embedding in &embeddings[1..] {
598 sum = sum.add(embedding)?;
599 }
600
601 sum.scalar_div(embeddings.len() as f32)
602 }
603
604 fn compute_prototypical_loss(
605 &self,
606 query_set: &ExampleSet,
607 prototypes: &[Tensor],
608 ) -> Result<f64, TrustformersError> {
609 let mut total_loss = 0.0;
610
611 for example in &query_set.examples {
612 let query_embedding = self.model.embed(example)?;
613 let distances = self.compute_distances(&query_embedding, prototypes)?;
614 let log_probs = self.compute_log_softmax(&distances, self.config.temperature)?;
615
616 total_loss -= log_probs[example.label];
618 }
619
620 Ok(total_loss / query_set.examples.len() as f64)
621 }
622
623 fn compute_prototypical_accuracy(
624 &self,
625 query_set: &ExampleSet,
626 prototypes: &[Tensor],
627 ) -> Result<f64, TrustformersError> {
628 let mut correct = 0;
629
630 for example in &query_set.examples {
631 let query_embedding = self.model.embed(example)?;
632 let distances = self.compute_distances(&query_embedding, prototypes)?;
633
634 let predicted_class = distances
636 .iter()
637 .enumerate()
638 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
639 .map(|(i, _)| i)
640 .unwrap_or(0);
641
642 if predicted_class == example.label {
643 correct += 1;
644 }
645 }
646
647 Ok(correct as f64 / query_set.examples.len() as f64)
648 }
649
650 fn compute_distances(
651 &self,
652 query: &Tensor,
653 prototypes: &[Tensor],
654 ) -> Result<Vec<f64>, TrustformersError> {
655 let mut distances = Vec::new();
656
657 for prototype in prototypes {
658 let diff = query.sub(prototype)?;
659 let distance = diff.norm()? as f64;
660 distances.push(distance);
661 }
662
663 Ok(distances)
664 }
665
666 fn compute_log_softmax(
667 &self,
668 distances: &[f64],
669 temperature: f64,
670 ) -> Result<Vec<f64>, TrustformersError> {
671 let neg_distances: Vec<f64> = distances.iter().map(|d| -d / temperature).collect();
673
674 let max_val = neg_distances.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
676 let exp_sum: f64 = neg_distances.iter().map(|x| (x - max_val).exp()).sum();
677 let log_sum = max_val + exp_sum.ln();
678
679 Ok(neg_distances.iter().map(|x| x - log_sum).collect())
680 }
681
682 fn compute_attention_weights(
684 &self,
685 _query_set: &ExampleSet,
686 _support_set: &ExampleSet,
687 ) -> Result<Vec<Vec<f64>>, TrustformersError> {
688 Ok(vec![vec![1.0]])
690 }
691
692 fn compute_matching_predictions(
693 &self,
694 _weights: &[Vec<f64>],
695 _support_set: &ExampleSet,
696 ) -> Result<Vec<Vec<f64>>, TrustformersError> {
697 Ok(vec![vec![1.0]])
698 }
699
700 fn compute_matching_loss(
701 &self,
702 _predictions: &[Vec<f64>],
703 _query_set: &ExampleSet,
704 ) -> Result<f64, TrustformersError> {
705 Ok(1.0)
706 }
707
708 fn compute_matching_accuracy(
709 &self,
710 _predictions: &[Vec<f64>],
711 _query_set: &ExampleSet,
712 ) -> Result<f64, TrustformersError> {
713 Ok(0.8)
714 }
715
716 fn compute_relation_loss(
717 &self,
718 _scores: &[f64],
719 _example: &Example,
720 _support_set: &ExampleSet,
721 ) -> Result<f64, TrustformersError> {
722 Ok(1.0)
723 }
724
725 fn is_correct_prediction(
726 &self,
727 _scores: &[f64],
728 _example: &Example,
729 _support_set: &ExampleSet,
730 ) -> Result<bool, TrustformersError> {
731 Ok(true)
732 }
733
734 fn compute_memory_loss(
735 &self,
736 _prediction: &MemoryPrediction,
737 _example: &Example,
738 ) -> Result<f64, TrustformersError> {
739 Ok(1.0)
740 }
741
742 fn is_memory_prediction_correct(
743 &self,
744 _prediction: &MemoryPrediction,
745 _example: &Example,
746 ) -> Result<bool, TrustformersError> {
747 Ok(true)
748 }
749
750 pub fn evaluate(
752 &mut self,
753 task_batch: TaskBatch,
754 ) -> Result<EvaluationResult, TrustformersError> {
755 let mut total_accuracy = 0.0;
756 let mut task_results = Vec::new();
757
758 for task in &task_batch.tasks {
759 let task_result = self.evaluate_single_task(task)?;
760 total_accuracy += task_result.query_accuracy;
761 task_results.push(task_result);
762 }
763
764 Ok(EvaluationResult {
765 average_accuracy: total_accuracy / task_batch.tasks.len() as f64,
766 task_results,
767 num_tasks: task_batch.tasks.len(),
768 })
769 }
770
771 fn evaluate_single_task(&mut self, task: &Task) -> Result<TaskResult, TrustformersError> {
772 match self.config.algorithm {
774 MetaAlgorithm::MAML | MetaAlgorithm::Reptile => {
775 let initial_params = self.model.get_parameters()?;
776
777 for _ in 0..self.config.inner_steps {
779 let support_loss = self.model.forward(&task.support_set)?;
780 let gradients = self.model.compute_gradients(support_loss)?;
781 self.model.apply_gradients(&gradients, self.config.inner_lr)?;
782 }
783
784 let query_loss = self.model.forward(&task.query_set)?;
786 let query_accuracy = self.model.compute_accuracy(&task.query_set)?;
787
788 self.model.set_parameters(initial_params)?;
790
791 Ok(TaskResult {
792 support_loss: 0.0,
793 query_loss,
794 query_accuracy,
795 adaptation_time: std::time::Duration::from_millis(0),
796 })
797 },
798 MetaAlgorithm::ProtoNet => self.train_prototypical_task(task),
799 _ => {
800 self.train_single_task(task)
803 },
804 }
805 }
806
807 pub fn get_statistics(&self) -> &MetaStatistics {
809 &self.meta_statistics
810 }
811
812 pub fn get_episode_history(&self) -> &[EpisodeResult] {
814 &self.episode_history
815 }
816
817 pub fn sample_task_batch(&mut self) -> Result<TaskBatch, TrustformersError> {
819 self.task_sampler.sample_batch(self.config.meta_batch_size)
820 }
821}
822
/// A single few-shot learning task: adapt on the support set, evaluate on
/// the query set.
#[derive(Debug, Clone)]
pub struct Task {
    /// Unique identifier assigned by the sampler.
    pub task_id: String,
    /// Examples used for inner-loop adaptation.
    pub support_set: ExampleSet,
    /// Held-out examples used to compute the query loss/accuracy.
    pub query_set: ExampleSet,
    /// Kind of prediction problem this task poses.
    pub task_type: TaskType,
}
831
/// A batch of tasks processed together in one meta-training episode.
#[derive(Debug, Clone)]
pub struct TaskBatch {
    /// Tasks making up the batch.
    pub tasks: Vec<Task>,
    /// Identifier assigned by the sampler.
    pub batch_id: String,
}
837
/// A labelled collection of examples (either a support or a query set).
#[derive(Debug, Clone)]
pub struct ExampleSet {
    /// The examples in this set.
    pub examples: Vec<Example>,
    /// Number of distinct classes the labels range over.
    pub num_classes: usize,
}
843
/// One labelled training/evaluation example.
#[derive(Debug, Clone)]
pub struct Example {
    /// Input features.
    pub input: Tensor,
    /// Class label, expected in `0..num_classes` of the containing set.
    pub label: usize,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
850
/// Kind of prediction problem a task poses. Only `Classification` is
/// produced by the sampler in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskType {
    Classification,
    Regression,
    Generation,
    SequenceLabeling,
}
858
/// Aggregated metrics for one completed meta-training episode.
#[derive(Debug, Clone)]
pub struct EpisodeResult {
    /// Episode index (0-based, monotonically increasing).
    pub episode: usize,
    /// Mean query loss across the episode's tasks.
    pub meta_loss: f64,
    /// Mean query accuracy across the episode's tasks.
    pub meta_accuracy: f64,
    /// Number of tasks processed in this episode.
    pub num_tasks: usize,
    /// Wall-clock duration of the episode.
    pub episode_time: std::time::Duration,
    /// Algorithm the episode was trained with.
    pub algorithm: MetaAlgorithm,
}
869
/// Per-task metrics produced by the task trainers.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Final support-set loss. NOTE(review): currently always 0.0 — the
    /// trainers never populate it.
    pub support_loss: f64,
    /// Loss on the held-out query set.
    pub query_loss: f64,
    /// Accuracy on the held-out query set.
    pub query_accuracy: f64,
    /// Wall-clock time spent adapting/evaluating on this task.
    pub adaptation_time: std::time::Duration,
}
877
/// Outcome of evaluating the learner on a batch of held-out tasks.
#[derive(Debug, Clone)]
pub struct EvaluationResult {
    /// Mean query accuracy across all evaluated tasks.
    pub average_accuracy: f64,
    /// Per-task results, in batch order.
    pub task_results: Vec<TaskResult>,
    /// Number of tasks evaluated.
    pub num_tasks: usize,
}
884
/// Running aggregate statistics over meta-training episodes.
#[derive(Debug)]
pub struct MetaStatistics {
    /// Total number of episodes folded in so far.
    pub total_episodes: usize,
    /// Exponential moving average of episode accuracy.
    pub average_accuracy: f64,
    /// Best episode accuracy seen so far.
    pub best_accuracy: f64,
    /// Sliding window of the last (up to) 100 episode accuracies.
    pub recent_accuracies: std::collections::VecDeque<f64>,
    /// |mean(window) - mean(oldest 50 of window)|; a rough drift measure.
    pub convergence_rate: f64,
}
893
impl Default for MetaStatistics {
    /// Same as [`MetaStatistics::new`].
    fn default() -> Self {
        Self::new()
    }
}
899
900impl MetaStatistics {
901 pub fn new() -> Self {
902 Self {
903 total_episodes: 0,
904 average_accuracy: 0.0,
905 best_accuracy: 0.0,
906 recent_accuracies: std::collections::VecDeque::with_capacity(100),
907 convergence_rate: 0.0,
908 }
909 }
910
911 pub fn update(&mut self, episode_result: &EpisodeResult) {
912 self.total_episodes += 1;
913
914 let alpha = 0.01; self.average_accuracy =
917 alpha * episode_result.meta_accuracy + (1.0 - alpha) * self.average_accuracy;
918
919 if episode_result.meta_accuracy > self.best_accuracy {
921 self.best_accuracy = episode_result.meta_accuracy;
922 }
923
924 self.recent_accuracies.push_back(episode_result.meta_accuracy);
926 if self.recent_accuracies.len() > 100 {
927 self.recent_accuracies.pop_front();
928 }
929
930 if self.recent_accuracies.len() > 10 {
932 let recent_mean =
933 self.recent_accuracies.iter().sum::<f64>() / self.recent_accuracies.len() as f64;
934 let older_mean = self.recent_accuracies.iter().take(50).sum::<f64>()
935 / (50.0f64).min(self.recent_accuracies.len() as f64);
936 self.convergence_rate = (recent_mean - older_mean).abs();
937 }
938 }
939}
940
/// Model interface used by [`MetaLearner`].
///
/// Only the first seven methods are required. The remaining methods are
/// algorithm-specific hooks with default implementations returning
/// `unsupported_operation`, so each concrete model overrides only the hooks
/// its algorithm needs.
pub trait MetaLearningModel: Send + Sync {
    /// Runs a forward pass over `examples` and returns the loss.
    fn forward(&mut self, examples: &ExampleSet) -> Result<f64, TrustformersError>;
    /// Computes accuracy over `examples`.
    fn compute_accuracy(&self, examples: &ExampleSet) -> Result<f64, TrustformersError>;
    /// Computes parameter gradients for the given loss value.
    fn compute_gradients(&self, loss: f64) -> Result<ModelGradients, TrustformersError>;
    /// Applies `gradients` scaled by learning rate `lr`.
    fn apply_gradients(
        &mut self,
        gradients: &ModelGradients,
        lr: f64,
    ) -> Result<(), TrustformersError>;
    /// Returns a snapshot of the current parameters.
    fn get_parameters(&self) -> Result<ModelParameters, TrustformersError>;
    /// Replaces the current parameters with `params`.
    fn set_parameters(&mut self, params: ModelParameters) -> Result<(), TrustformersError>;
    /// Embeds a single example into the model's embedding space.
    fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError>;

    /// MAML hook: meta-gradients through the inner-loop adaptation.
    fn compute_second_order_gradients(
        &self,
        _initial_params: &ModelParameters,
        _loss: f64,
    ) -> Result<ModelGradients, TrustformersError> {
        Err(unsupported_operation(
            "compute_second_order_gradients",
            "meta_learning",
        ))
    }

    /// First-order MAML hook: meta-gradients without second-order terms.
    fn compute_first_order_gradients(
        &self,
        _loss: f64,
    ) -> Result<ModelGradients, TrustformersError> {
        Err(unsupported_operation(
            "compute_first_order_gradients",
            "meta_learning",
        ))
    }

    /// RelationNet hook: relation score between two embeddings.
    fn compute_relation(&self, _emb1: &Tensor, _emb2: &Tensor) -> Result<f64, TrustformersError> {
        Err(unsupported_operation("compute_relation", "meta_learning"))
    }

    /// MANN hook: stores an example in external memory.
    fn write_to_memory(&mut self, _example: &Example) -> Result<(), TrustformersError> {
        Err(unsupported_operation("write_to_memory", "meta_learning"))
    }

    /// MANN hook: reads memory content relevant to the example.
    fn read_from_memory(&self, _example: &Example) -> Result<MemoryOutput, TrustformersError> {
        Err(unsupported_operation("read_from_memory", "meta_learning"))
    }

    /// MANN hook: turns a memory read into a prediction.
    fn predict_from_memory(
        &self,
        _memory_output: &MemoryOutput,
    ) -> Result<MemoryPrediction, TrustformersError> {
        Err(unsupported_operation(
            "predict_from_memory",
            "meta_learning",
        ))
    }

    /// MANN hook: clears external memory between tasks. Defaults to a no-op.
    fn clear_memory(&mut self) -> Result<(), TrustformersError> {
        Ok(())
    }

    /// Meta-SGD hook: the learned per-parameter learning rates.
    fn get_learning_rates(&self) -> Result<Vec<f64>, TrustformersError> {
        Err(unsupported_operation("get_learning_rates", "meta_learning"))
    }

    /// Meta-SGD hook: applies gradients with per-parameter learning rates.
    fn apply_gradients_with_lr(
        &mut self,
        _gradients: &ModelGradients,
        _learning_rates: &[f64],
    ) -> Result<(), TrustformersError> {
        Err(unsupported_operation(
            "apply_gradients_with_lr",
            "meta_learning",
        ))
    }

    /// Meta-SGD hook: gradients with respect to the learning rates.
    fn compute_lr_gradients(&self, _loss: f64) -> Result<Vec<f64>, TrustformersError> {
        Err(unsupported_operation(
            "compute_lr_gradients",
            "meta_learning",
        ))
    }

    /// GBML hook: state of the learned update rule.
    fn get_meta_learner_state(&self) -> Result<MetaLearnerState, TrustformersError> {
        Err(unsupported_operation(
            "get_meta_learner_state",
            "meta_learning",
        ))
    }

    /// GBML hook: adapts parameters using the learned algorithm.
    fn apply_learned_algorithm(
        &self,
        _support_set: &ExampleSet,
        _state: &MetaLearnerState,
    ) -> Result<ModelParameters, TrustformersError> {
        Err(unsupported_operation(
            "apply_learned_algorithm",
            "meta_learning",
        ))
    }

    /// GBML hook: loss evaluated under the provided parameters.
    fn evaluate_with_params(
        &self,
        _examples: &ExampleSet,
        _params: &ModelParameters,
    ) -> Result<f64, TrustformersError> {
        Err(unsupported_operation(
            "evaluate_with_params",
            "meta_learning",
        ))
    }

    /// GBML hook: accuracy evaluated under the provided parameters.
    fn compute_accuracy_with_params(
        &self,
        _examples: &ExampleSet,
        _params: &ModelParameters,
    ) -> Result<f64, TrustformersError> {
        Err(unsupported_operation(
            "compute_accuracy_with_params",
            "meta_learning",
        ))
    }

    /// GBML hook: gradients for the meta-learner itself.
    fn compute_meta_learner_gradients(
        &self,
        _loss: f64,
    ) -> Result<ModelGradients, TrustformersError> {
        Err(unsupported_operation(
            "compute_meta_learner_gradients",
            "meta_learning",
        ))
    }

    /// L2L hook: current LSTM optimizer state.
    fn get_lstm_state(&self) -> Result<LSTMState, TrustformersError> {
        Err(unsupported_operation("get_lstm_state", "meta_learning"))
    }

    /// L2L hook: one LSTM-optimizer step; returns updates and next state.
    fn lstm_update(
        &self,
        _gradients: &ModelGradients,
        _state: &LSTMState,
        _step: usize,
    ) -> Result<(ModelUpdates, LSTMState), TrustformersError> {
        Err(unsupported_operation("lstm_update", "meta_learning"))
    }

    /// L2L hook: applies the LSTM-proposed parameter updates.
    fn apply_lstm_updates(&mut self, _updates: &ModelUpdates) -> Result<(), TrustformersError> {
        Err(unsupported_operation("apply_lstm_updates", "meta_learning"))
    }

    /// L2L hook: gradients for the LSTM optimizer.
    fn compute_lstm_gradients(&self, _loss: f64) -> Result<ModelGradients, TrustformersError> {
        Err(unsupported_operation(
            "compute_lstm_gradients",
            "meta_learning",
        ))
    }
}
1099
/// Outer-loop optimizer: collects per-task meta-gradients and applies one
/// aggregated update to the model.
pub trait MetaOptimizer: Send + Sync {
    /// Applies the accumulated gradients to `model`.
    fn step(&mut self, model: &mut dyn MetaLearningModel) -> Result<(), TrustformersError>;
    /// Adds one task's meta-gradients to the accumulator.
    fn accumulate_gradients(&mut self, gradients: ModelGradients) -> Result<(), TrustformersError>;
    /// Meta-SGD hook for parameter gradients; defaults to plain accumulation.
    fn accumulate_param_gradients(
        &mut self,
        _gradients: ModelGradients,
    ) -> Result<(), TrustformersError> {
        self.accumulate_gradients(_gradients)
    }
    /// Meta-SGD hook for learning-rate gradients; defaults to a no-op.
    fn accumulate_lr_gradients(
        &mut self,
        _lr_gradients: Vec<f64>,
    ) -> Result<(), TrustformersError> {
        Ok(())
    }
    /// Clears any accumulated state.
    fn reset(&mut self) -> Result<(), TrustformersError>;
}
1117
/// A named snapshot of model parameters.
#[derive(Debug, Clone)]
pub struct ModelParameters {
    /// Parameter tensors keyed by name.
    pub parameters: HashMap<String, Tensor>,
}
1123
/// Named gradients matching [`ModelParameters`] by key.
#[derive(Debug, Clone)]
pub struct ModelGradients {
    /// Gradient tensors keyed by parameter name.
    pub gradients: HashMap<String, Tensor>,
}
1128
impl Default for ModelGradients {
    /// Same as [`ModelGradients::new`].
    fn default() -> Self {
        Self::new()
    }
}
1134
1135impl ModelGradients {
1136 pub fn new() -> Self {
1137 Self {
1138 gradients: HashMap::new(),
1139 }
1140 }
1141}
1142
/// Result of a MANN memory read.
#[derive(Debug, Clone)]
pub struct MemoryOutput {
    /// Retrieved memory content.
    pub content: Tensor,
    /// Attention weight over each memory slot.
    pub attention_weights: Vec<f64>,
}
1148
/// Prediction derived from a MANN memory read.
#[derive(Debug, Clone)]
pub struct MemoryPrediction {
    /// Class logits.
    pub logits: Tensor,
    /// Scalar confidence in the prediction.
    pub confidence: f64,
}
1154
/// Recurrent state of a GBML learned update rule.
#[derive(Debug, Clone)]
pub struct MetaLearnerState {
    /// Hidden state tensor.
    pub hidden_state: Tensor,
    /// Cell state tensor.
    pub cell_state: Tensor,
}
1160
/// State of the L2L LSTM optimizer.
#[derive(Debug, Clone)]
pub struct LSTMState {
    /// Hidden state tensor.
    pub hidden: Tensor,
    /// Cell state tensor.
    pub cell: Tensor,
}
1166
/// Parameter updates proposed by the L2L LSTM optimizer, keyed by name.
#[derive(Debug, Clone)]
pub struct ModelUpdates {
    /// Update tensors keyed by parameter name.
    pub updates: HashMap<String, Tensor>,
}
1171
/// Generates synthetic N-way classification tasks for meta-training.
pub struct TaskSampler {
    /// Sampler's own copy of the meta-learning configuration.
    config: MetaLearningConfig,
    /// Candidate task distributions. NOTE(review): never populated or read
    /// in this file.
    #[allow(dead_code)]
    task_distributions: Vec<TaskDistribution>,
    /// Monotonic counter used to derive task/batch ids.
    current_task_id: usize,
}
1179
1180impl TaskSampler {
1181 pub fn new(config: &MetaLearningConfig) -> Result<Self, TrustformersError> {
1182 Ok(Self {
1183 config: config.clone(),
1184 task_distributions: Vec::new(),
1185 current_task_id: 0,
1186 })
1187 }
1188
1189 pub fn sample_batch(&mut self, batch_size: usize) -> Result<TaskBatch, TrustformersError> {
1190 let mut tasks = Vec::new();
1191
1192 for _ in 0..batch_size {
1193 let task = self.sample_single_task()?;
1194 tasks.push(task);
1195 }
1196
1197 Ok(TaskBatch {
1198 tasks,
1199 batch_id: format!("batch_{}", self.current_task_id),
1200 })
1201 }
1202
1203 fn sample_single_task(&mut self) -> Result<Task, TrustformersError> {
1204 let support_set = self.create_example_set(self.config.support_size)?;
1206 let query_set = self.create_example_set(self.config.query_size)?;
1207
1208 self.current_task_id += 1;
1209
1210 Ok(Task {
1211 task_id: format!("task_{}", self.current_task_id),
1212 support_set,
1213 query_set,
1214 task_type: TaskType::Classification,
1215 })
1216 }
1217
1218 fn create_example_set(&self, size: usize) -> Result<ExampleSet, TrustformersError> {
1219 let mut examples = Vec::new();
1220
1221 for i in 0..size {
1222 let input = Tensor::randn(&[self.config.embedding_dim])?;
1223 let label = i % self.config.num_ways; examples.push(Example {
1226 input,
1227 label,
1228 metadata: HashMap::new(),
1229 });
1230 }
1231
1232 Ok(ExampleSet {
1233 examples,
1234 num_classes: self.config.num_ways,
1235 })
1236 }
1237}
1238
/// A named source of tasks with a relative sampling weight.
/// NOTE(review): declared but never used in this file.
#[derive(Debug)]
pub struct TaskDistribution {
    /// Human-readable distribution name.
    pub name: String,
    /// Relative weight when choosing among distributions.
    pub sampling_weight: f64,
}
1244
/// Placeholder MAML model: stores the config; all trait methods are stubs.
pub struct MAMLModel {
    /// Retained configuration. Unused by the current stub implementation.
    #[allow(dead_code)]
    config: MetaLearningConfig,
}
1251
1252impl MAMLModel {
1253 pub fn new(config: &MetaLearningConfig) -> Result<Self, TrustformersError> {
1254 Ok(Self {
1255 config: config.clone(),
1256 })
1257 }
1258}
1259
1260impl MetaLearningModel for MAMLModel {
1261 fn forward(&mut self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1262 Ok(0.5) }
1264
1265 fn compute_accuracy(&self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1266 Ok(0.8) }
1268
1269 fn compute_gradients(&self, _loss: f64) -> Result<ModelGradients, TrustformersError> {
1270 Ok(ModelGradients::new())
1271 }
1272
1273 fn apply_gradients(
1274 &mut self,
1275 _gradients: &ModelGradients,
1276 _lr: f64,
1277 ) -> Result<(), TrustformersError> {
1278 Ok(())
1279 }
1280
1281 fn get_parameters(&self) -> Result<ModelParameters, TrustformersError> {
1282 Ok(ModelParameters {
1283 parameters: HashMap::new(),
1284 })
1285 }
1286
1287 fn set_parameters(&mut self, _params: ModelParameters) -> Result<(), TrustformersError> {
1288 Ok(())
1289 }
1290
1291 fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError> {
1292 Ok(example.input.clone())
1293 }
1294
1295 fn compute_second_order_gradients(
1296 &self,
1297 _initial_params: &ModelParameters,
1298 _loss: f64,
1299 ) -> Result<ModelGradients, TrustformersError> {
1300 Ok(ModelGradients::new())
1301 }
1302
1303 fn compute_first_order_gradients(
1304 &self,
1305 _loss: f64,
1306 ) -> Result<ModelGradients, TrustformersError> {
1307 Ok(ModelGradients::new())
1308 }
1309}
1310
1311pub struct ReptileModel {
1313 #[allow(dead_code)]
1314 config: MetaLearningConfig,
1315}
1316
1317impl ReptileModel {
1318 pub fn new(config: &MetaLearningConfig) -> Result<Self, TrustformersError> {
1319 Ok(Self {
1320 config: config.clone(),
1321 })
1322 }
1323}
1324
1325impl MetaLearningModel for ReptileModel {
1326 fn forward(&mut self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1327 Ok(0.5)
1328 }
1329 fn compute_accuracy(&self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1330 Ok(0.8)
1331 }
1332 fn compute_gradients(&self, _loss: f64) -> Result<ModelGradients, TrustformersError> {
1333 Ok(ModelGradients::new())
1334 }
1335 fn apply_gradients(
1336 &mut self,
1337 _gradients: &ModelGradients,
1338 _lr: f64,
1339 ) -> Result<(), TrustformersError> {
1340 Ok(())
1341 }
1342 fn get_parameters(&self) -> Result<ModelParameters, TrustformersError> {
1343 Ok(ModelParameters {
1344 parameters: HashMap::new(),
1345 })
1346 }
1347 fn set_parameters(&mut self, _params: ModelParameters) -> Result<(), TrustformersError> {
1348 Ok(())
1349 }
1350 fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError> {
1351 Ok(example.input.clone())
1352 }
1353}
1354
1355pub struct PrototypicalModel {
1356 #[allow(dead_code)]
1357 config: MetaLearningConfig,
1358}
1359impl PrototypicalModel {
1360 pub fn new(config: &MetaLearningConfig) -> Result<Self, TrustformersError> {
1361 Ok(Self {
1362 config: config.clone(),
1363 })
1364 }
1365}
1366impl MetaLearningModel for PrototypicalModel {
1367 fn forward(&mut self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1368 Ok(0.5)
1369 }
1370 fn compute_accuracy(&self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1371 Ok(0.8)
1372 }
1373 fn compute_gradients(&self, _loss: f64) -> Result<ModelGradients, TrustformersError> {
1374 Ok(ModelGradients::new())
1375 }
1376 fn apply_gradients(
1377 &mut self,
1378 _gradients: &ModelGradients,
1379 _lr: f64,
1380 ) -> Result<(), TrustformersError> {
1381 Ok(())
1382 }
1383 fn get_parameters(&self) -> Result<ModelParameters, TrustformersError> {
1384 Ok(ModelParameters {
1385 parameters: HashMap::new(),
1386 })
1387 }
1388 fn set_parameters(&mut self, _params: ModelParameters) -> Result<(), TrustformersError> {
1389 Ok(())
1390 }
1391 fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError> {
1392 Ok(example.input.clone())
1393 }
1394}
1395
1396pub struct MatchingNetModel {
1397 #[allow(dead_code)]
1398 config: MetaLearningConfig,
1399}
1400impl MatchingNetModel {
1401 pub fn new(config: &MetaLearningConfig) -> Result<Self, TrustformersError> {
1402 Ok(Self {
1403 config: config.clone(),
1404 })
1405 }
1406}
1407impl MetaLearningModel for MatchingNetModel {
1408 fn forward(&mut self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1409 Ok(0.5)
1410 }
1411 fn compute_accuracy(&self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1412 Ok(0.8)
1413 }
1414 fn compute_gradients(&self, _loss: f64) -> Result<ModelGradients, TrustformersError> {
1415 Ok(ModelGradients::new())
1416 }
1417 fn apply_gradients(
1418 &mut self,
1419 _gradients: &ModelGradients,
1420 _lr: f64,
1421 ) -> Result<(), TrustformersError> {
1422 Ok(())
1423 }
1424 fn get_parameters(&self) -> Result<ModelParameters, TrustformersError> {
1425 Ok(ModelParameters {
1426 parameters: HashMap::new(),
1427 })
1428 }
1429 fn set_parameters(&mut self, _params: ModelParameters) -> Result<(), TrustformersError> {
1430 Ok(())
1431 }
1432 fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError> {
1433 Ok(example.input.clone())
1434 }
1435}
1436
1437pub struct RelationNetModel {
1438 #[allow(dead_code)]
1439 config: MetaLearningConfig,
1440}
1441impl RelationNetModel {
1442 pub fn new(config: &MetaLearningConfig) -> Result<Self, TrustformersError> {
1443 Ok(Self {
1444 config: config.clone(),
1445 })
1446 }
1447}
1448impl MetaLearningModel for RelationNetModel {
1449 fn forward(&mut self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1450 Ok(0.5)
1451 }
1452 fn compute_accuracy(&self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1453 Ok(0.8)
1454 }
1455 fn compute_gradients(&self, _loss: f64) -> Result<ModelGradients, TrustformersError> {
1456 Ok(ModelGradients::new())
1457 }
1458 fn apply_gradients(
1459 &mut self,
1460 _gradients: &ModelGradients,
1461 _lr: f64,
1462 ) -> Result<(), TrustformersError> {
1463 Ok(())
1464 }
1465 fn get_parameters(&self) -> Result<ModelParameters, TrustformersError> {
1466 Ok(ModelParameters {
1467 parameters: HashMap::new(),
1468 })
1469 }
1470 fn set_parameters(&mut self, _params: ModelParameters) -> Result<(), TrustformersError> {
1471 Ok(())
1472 }
1473 fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError> {
1474 Ok(example.input.clone())
1475 }
1476 fn compute_relation(&self, _emb1: &Tensor, _emb2: &Tensor) -> Result<f64, TrustformersError> {
1477 Ok(0.5)
1478 }
1479}
1480
1481pub struct MemoryAugmentedModel {
1482 #[allow(dead_code)]
1483 config: MetaLearningConfig,
1484}
1485impl MemoryAugmentedModel {
1486 pub fn new(config: &MetaLearningConfig) -> Result<Self, TrustformersError> {
1487 Ok(Self {
1488 config: config.clone(),
1489 })
1490 }
1491}
1492impl MetaLearningModel for MemoryAugmentedModel {
1493 fn forward(&mut self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1494 Ok(0.5)
1495 }
1496 fn compute_accuracy(&self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1497 Ok(0.8)
1498 }
1499 fn compute_gradients(&self, _loss: f64) -> Result<ModelGradients, TrustformersError> {
1500 Ok(ModelGradients::new())
1501 }
1502 fn apply_gradients(
1503 &mut self,
1504 _gradients: &ModelGradients,
1505 _lr: f64,
1506 ) -> Result<(), TrustformersError> {
1507 Ok(())
1508 }
1509 fn get_parameters(&self) -> Result<ModelParameters, TrustformersError> {
1510 Ok(ModelParameters {
1511 parameters: HashMap::new(),
1512 })
1513 }
1514 fn set_parameters(&mut self, _params: ModelParameters) -> Result<(), TrustformersError> {
1515 Ok(())
1516 }
1517 fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError> {
1518 Ok(example.input.clone())
1519 }
1520 fn write_to_memory(&mut self, _example: &Example) -> Result<(), TrustformersError> {
1521 Ok(())
1522 }
1523 fn read_from_memory(&self, _example: &Example) -> Result<MemoryOutput, TrustformersError> {
1524 Ok(MemoryOutput {
1525 content: Tensor::zeros(&[64])?,
1526 attention_weights: vec![1.0],
1527 })
1528 }
1529 fn predict_from_memory(
1530 &self,
1531 _memory_output: &MemoryOutput,
1532 ) -> Result<MemoryPrediction, TrustformersError> {
1533 Ok(MemoryPrediction {
1534 logits: Tensor::zeros(&[5])?,
1535 confidence: 0.8,
1536 })
1537 }
1538}
1539
1540pub struct GradientBasedModel {
1541 #[allow(dead_code)]
1542 config: MetaLearningConfig,
1543}
1544impl GradientBasedModel {
1545 pub fn new(config: &MetaLearningConfig) -> Result<Self, TrustformersError> {
1546 Ok(Self {
1547 config: config.clone(),
1548 })
1549 }
1550}
1551impl MetaLearningModel for GradientBasedModel {
1552 fn forward(&mut self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1553 Ok(0.5)
1554 }
1555 fn compute_accuracy(&self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1556 Ok(0.8)
1557 }
1558 fn compute_gradients(&self, _loss: f64) -> Result<ModelGradients, TrustformersError> {
1559 Ok(ModelGradients::new())
1560 }
1561 fn apply_gradients(
1562 &mut self,
1563 _gradients: &ModelGradients,
1564 _lr: f64,
1565 ) -> Result<(), TrustformersError> {
1566 Ok(())
1567 }
1568 fn get_parameters(&self) -> Result<ModelParameters, TrustformersError> {
1569 Ok(ModelParameters {
1570 parameters: HashMap::new(),
1571 })
1572 }
1573 fn set_parameters(&mut self, _params: ModelParameters) -> Result<(), TrustformersError> {
1574 Ok(())
1575 }
1576 fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError> {
1577 Ok(example.input.clone())
1578 }
1579}
1580
1581pub struct MetaSGDModel {
1582 #[allow(dead_code)]
1583 config: MetaLearningConfig,
1584}
1585impl MetaSGDModel {
1586 pub fn new(config: &MetaLearningConfig) -> Result<Self, TrustformersError> {
1587 Ok(Self {
1588 config: config.clone(),
1589 })
1590 }
1591}
1592impl MetaLearningModel for MetaSGDModel {
1593 fn forward(&mut self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1594 Ok(0.5)
1595 }
1596 fn compute_accuracy(&self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1597 Ok(0.8)
1598 }
1599 fn compute_gradients(&self, _loss: f64) -> Result<ModelGradients, TrustformersError> {
1600 Ok(ModelGradients::new())
1601 }
1602 fn apply_gradients(
1603 &mut self,
1604 _gradients: &ModelGradients,
1605 _lr: f64,
1606 ) -> Result<(), TrustformersError> {
1607 Ok(())
1608 }
1609 fn get_parameters(&self) -> Result<ModelParameters, TrustformersError> {
1610 Ok(ModelParameters {
1611 parameters: HashMap::new(),
1612 })
1613 }
1614 fn set_parameters(&mut self, _params: ModelParameters) -> Result<(), TrustformersError> {
1615 Ok(())
1616 }
1617 fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError> {
1618 Ok(example.input.clone())
1619 }
1620}
1621
1622pub struct L2LModel {
1623 #[allow(dead_code)]
1624 config: MetaLearningConfig,
1625}
1626impl L2LModel {
1627 pub fn new(config: &MetaLearningConfig) -> Result<Self, TrustformersError> {
1628 Ok(Self {
1629 config: config.clone(),
1630 })
1631 }
1632}
1633impl MetaLearningModel for L2LModel {
1634 fn forward(&mut self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1635 Ok(0.5)
1636 }
1637 fn compute_accuracy(&self, _examples: &ExampleSet) -> Result<f64, TrustformersError> {
1638 Ok(0.8)
1639 }
1640 fn compute_gradients(&self, _loss: f64) -> Result<ModelGradients, TrustformersError> {
1641 Ok(ModelGradients::new())
1642 }
1643 fn apply_gradients(
1644 &mut self,
1645 _gradients: &ModelGradients,
1646 _lr: f64,
1647 ) -> Result<(), TrustformersError> {
1648 Ok(())
1649 }
1650 fn get_parameters(&self) -> Result<ModelParameters, TrustformersError> {
1651 Ok(ModelParameters {
1652 parameters: HashMap::new(),
1653 })
1654 }
1655 fn set_parameters(&mut self, _params: ModelParameters) -> Result<(), TrustformersError> {
1656 Ok(())
1657 }
1658 fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError> {
1659 Ok(example.input.clone())
1660 }
1661}
1662
1663pub struct SGDMetaOptimizer {
1665 learning_rate: f64,
1666 accumulated_gradients: Option<ModelGradients>,
1667}
1668
1669impl SGDMetaOptimizer {
1670 pub fn new(learning_rate: f64) -> Result<Self, TrustformersError> {
1671 Ok(Self {
1672 learning_rate,
1673 accumulated_gradients: None,
1674 })
1675 }
1676}
1677
1678impl MetaOptimizer for SGDMetaOptimizer {
1679 fn step(&mut self, model: &mut dyn MetaLearningModel) -> Result<(), TrustformersError> {
1680 if let Some(gradients) = &self.accumulated_gradients {
1681 model.apply_gradients(gradients, self.learning_rate)?;
1682 self.accumulated_gradients = None;
1683 }
1684 Ok(())
1685 }
1686
1687 fn accumulate_gradients(&mut self, gradients: ModelGradients) -> Result<(), TrustformersError> {
1688 self.accumulated_gradients = Some(gradients);
1689 Ok(())
1690 }
1691
1692 fn reset(&mut self) -> Result<(), TrustformersError> {
1693 self.accumulated_gradients = None;
1694 Ok(())
1695 }
1696}
1697
/// "Adam" meta-optimizer.
///
/// NOTE(review): despite the name, this currently behaves exactly like plain
/// SGD — no first/second moment estimates or bias correction are kept.
/// Confirm whether full Adam semantics are intended here.
pub struct AdamMetaOptimizer {
    // Fixed meta learning rate applied in `step`.
    learning_rate: f64,
    // At most one pending gradient buffer (replaced, not summed).
    accumulated_gradients: Option<ModelGradients>,
}

impl AdamMetaOptimizer {
    /// Creates the optimizer with the given meta learning rate.
    pub fn new(learning_rate: f64) -> Result<Self, TrustformersError> {
        Ok(Self {
            learning_rate,
            accumulated_gradients: None,
        })
    }
}

impl MetaOptimizer for AdamMetaOptimizer {
    /// Applies any buffered gradients with the fixed learning rate, then
    /// clears the buffer. No-op when nothing has been accumulated.
    fn step(&mut self, model: &mut dyn MetaLearningModel) -> Result<(), TrustformersError> {
        if let Some(gradients) = &self.accumulated_gradients {
            model.apply_gradients(gradients, self.learning_rate)?;
            self.accumulated_gradients = None;
        }
        Ok(())
    }

    /// Buffers gradients for the next `step`, replacing any previous buffer.
    fn accumulate_gradients(&mut self, gradients: ModelGradients) -> Result<(), TrustformersError> {
        self.accumulated_gradients = Some(gradients);
        Ok(())
    }

    /// Discards any buffered gradients.
    fn reset(&mut self) -> Result<(), TrustformersError> {
        self.accumulated_gradients = None;
        Ok(())
    }
}
1731
1732pub struct LearnedLROptimizer {
1733 learning_rate: f64,
1734 accumulated_gradients: Option<ModelGradients>,
1735 accumulated_lr_gradients: Option<Vec<f64>>,
1736}
1737
1738impl LearnedLROptimizer {
1739 pub fn new(learning_rate: f64) -> Result<Self, TrustformersError> {
1740 Ok(Self {
1741 learning_rate,
1742 accumulated_gradients: None,
1743 accumulated_lr_gradients: None,
1744 })
1745 }
1746}
1747
1748impl MetaOptimizer for LearnedLROptimizer {
1749 fn step(&mut self, model: &mut dyn MetaLearningModel) -> Result<(), TrustformersError> {
1750 if let Some(gradients) = &self.accumulated_gradients {
1751 model.apply_gradients(gradients, self.learning_rate)?;
1752 self.accumulated_gradients = None;
1753 }
1754 Ok(())
1755 }
1756
1757 fn accumulate_gradients(&mut self, gradients: ModelGradients) -> Result<(), TrustformersError> {
1758 self.accumulated_gradients = Some(gradients);
1759 Ok(())
1760 }
1761
1762 fn accumulate_lr_gradients(&mut self, lr_gradients: Vec<f64>) -> Result<(), TrustformersError> {
1763 self.accumulated_lr_gradients = Some(lr_gradients);
1764 Ok(())
1765 }
1766
1767 fn reset(&mut self) -> Result<(), TrustformersError> {
1768 self.accumulated_gradients = None;
1769 self.accumulated_lr_gradients = None;
1770 Ok(())
1771 }
1772}
1773
1774pub mod utils {
1776 use super::*;
1777
1778 pub fn create_few_shot_config(
1780 num_ways: usize,
1781 num_shots: usize,
1782 query_size: usize,
1783 ) -> MetaLearningConfig {
1784 MetaLearningConfig {
1785 num_ways,
1786 num_shots,
1787 support_size: num_ways * num_shots,
1788 query_size,
1789 ..Default::default()
1790 }
1791 }
1792
1793 pub fn create_maml_config() -> MetaLearningConfig {
1795 MetaLearningConfig {
1796 algorithm: MetaAlgorithm::MAML,
1797 inner_lr: 0.01,
1798 meta_lr: 0.001,
1799 inner_steps: 5,
1800 first_order: false,
1801 ..Default::default()
1802 }
1803 }
1804
1805 pub fn create_reptile_config() -> MetaLearningConfig {
1807 MetaLearningConfig {
1808 algorithm: MetaAlgorithm::Reptile,
1809 inner_lr: 0.01,
1810 meta_lr: 0.001,
1811 inner_steps: 10,
1812 first_order: true,
1813 ..Default::default()
1814 }
1815 }
1816
1817 pub fn create_protonet_config() -> MetaLearningConfig {
1819 MetaLearningConfig {
1820 algorithm: MetaAlgorithm::ProtoNet,
1821 temperature: 1.0,
1822 normalize_embeddings: true,
1823 embedding_dim: 512,
1824 ..Default::default()
1825 }
1826 }
1827
1828 pub fn calculate_performance_metrics(episode_results: &[EpisodeResult]) -> PerformanceMetrics {
1830 if episode_results.is_empty() {
1831 return PerformanceMetrics::default();
1832 }
1833
1834 let accuracies: Vec<f64> = episode_results.iter().map(|r| r.meta_accuracy).collect();
1835 let mean_accuracy = accuracies.iter().sum::<f64>() / accuracies.len() as f64;
1836
1837 let variance = accuracies.iter().map(|acc| (acc - mean_accuracy).powi(2)).sum::<f64>()
1838 / accuracies.len() as f64;
1839 let std_dev = variance.sqrt();
1840
1841 let max_accuracy = accuracies.iter().fold(0.0f64, |a, &b| a.max(b));
1842 let min_accuracy = accuracies.iter().fold(1.0f64, |a, &b| a.min(b));
1843
1844 PerformanceMetrics {
1845 mean_accuracy,
1846 std_dev,
1847 max_accuracy,
1848 min_accuracy,
1849 num_episodes: episode_results.len(),
1850 }
1851 }
1852
1853 pub fn estimate_convergence(
1855 episode_results: &[EpisodeResult],
1856 window_size: usize,
1857 ) -> ConvergenceMetrics {
1858 if episode_results.len() < window_size * 2 {
1859 return ConvergenceMetrics::default();
1860 }
1861
1862 let recent_window = &episode_results[episode_results.len() - window_size..];
1863 let older_window = &episode_results
1864 [episode_results.len() - window_size * 2..episode_results.len() - window_size];
1865
1866 let recent_mean =
1867 recent_window.iter().map(|r| r.meta_accuracy).sum::<f64>() / window_size as f64;
1868 let older_mean =
1869 older_window.iter().map(|r| r.meta_accuracy).sum::<f64>() / window_size as f64;
1870
1871 let improvement_rate = recent_mean - older_mean;
1872 let has_converged = improvement_rate.abs() < 0.001;
1873
1874 ConvergenceMetrics {
1875 improvement_rate,
1876 has_converged,
1877 recent_mean,
1878 older_mean,
1879 }
1880 }
1881}
1882
/// Summary accuracy statistics over a set of meta-training episodes.
#[derive(Debug, Default)]
pub struct PerformanceMetrics {
    /// Mean episode meta-accuracy.
    pub mean_accuracy: f64,
    /// Population standard deviation of episode meta-accuracies.
    pub std_dev: f64,
    /// Highest episode meta-accuracy observed.
    pub max_accuracy: f64,
    /// Lowest episode meta-accuracy observed.
    pub min_accuracy: f64,
    /// Number of episodes the statistics were computed over.
    pub num_episodes: usize,
}
1891
/// Window-to-window convergence estimate for meta-training.
#[derive(Debug, Default)]
pub struct ConvergenceMetrics {
    /// Difference between the recent and older windows' mean accuracies.
    pub improvement_rate: f64,
    /// True when the absolute improvement rate is negligible.
    pub has_converged: bool,
    /// Mean accuracy over the most recent window.
    pub recent_mean: f64,
    /// Mean accuracy over the preceding window.
    pub older_mean: f64,
}
1899
#[cfg(test)]
mod tests {
    use super::*;

    // Default config should select MAML with 5-way 1-shot settings.
    #[test]
    fn test_meta_learning_config_default() {
        let config = MetaLearningConfig::default();
        assert_eq!(config.algorithm, MetaAlgorithm::MAML);
        assert_eq!(config.num_ways, 5);
        assert_eq!(config.num_shots, 1);
    }

    // Constructing a MetaLearner from the default config must succeed.
    #[test]
    fn test_meta_learner_creation() {
        let config = MetaLearningConfig::default();
        let result = MetaLearner::new(config);
        assert!(result.is_ok());
    }

    // A sampled batch contains exactly the requested number of tasks.
    #[test]
    fn test_task_sampler() {
        let config = MetaLearningConfig::default();
        let mut sampler = TaskSampler::new(&config).expect("operation failed");
        let task_batch = sampler.sample_batch(4).expect("operation failed");
        assert_eq!(task_batch.tasks.len(), 4);
    }

    // Recording one episode should bump counters and the best accuracy.
    #[test]
    fn test_meta_statistics() {
        let mut stats = MetaStatistics::new();
        let episode_result = EpisodeResult {
            episode: 0,
            meta_loss: 0.5,
            meta_accuracy: 0.8,
            num_tasks: 10,
            episode_time: std::time::Duration::from_millis(100),
            algorithm: MetaAlgorithm::MAML,
        };

        stats.update(&episode_result);
        assert!(stats.total_episodes > 0);
        assert!(stats.best_accuracy > 0.0);
    }

    // support_size is derived as num_ways * num_shots (5 * 1 = 5 here).
    #[test]
    fn test_utils_few_shot_config() {
        let config = utils::create_few_shot_config(5, 1, 15);
        assert_eq!(config.num_ways, 5);
        assert_eq!(config.num_shots, 1);
        assert_eq!(config.support_size, 5);
        assert_eq!(config.query_size, 15);
    }

    // Variants are distinct; discriminants follow declaration order.
    #[test]
    fn test_meta_algorithms() {
        assert_ne!(MetaAlgorithm::MAML, MetaAlgorithm::Reptile);
        assert_eq!(MetaAlgorithm::ProtoNet as u8, 2);
    }

    // Two episodes with accuracies 0.8 and 0.85 → mean above 0.8.
    #[test]
    fn test_performance_metrics_calculation() {
        let episode_results = vec![
            EpisodeResult {
                episode: 0,
                meta_loss: 0.5,
                meta_accuracy: 0.8,
                num_tasks: 10,
                episode_time: std::time::Duration::from_millis(100),
                algorithm: MetaAlgorithm::MAML,
            },
            EpisodeResult {
                episode: 1,
                meta_loss: 0.4,
                meta_accuracy: 0.85,
                num_tasks: 10,
                episode_time: std::time::Duration::from_millis(100),
                algorithm: MetaAlgorithm::MAML,
            },
        ];

        let metrics = utils::calculate_performance_metrics(&episode_results);
        assert!(metrics.mean_accuracy > 0.8);
        assert!(metrics.std_dev >= 0.0);
        assert_eq!(metrics.num_episodes, 2);
    }
}