// tensorlogic_train/meta_learning.rs

//! Meta-learning algorithms for learning to learn.
//!
//! This module implements meta-learning algorithms that learn model initializations
//! or update rules that enable rapid adaptation to new tasks with minimal data.
//!
//! # Overview
//!
//! Meta-learning (or "learning to learn") aims to improve a model's ability to
//! quickly adapt to new tasks by learning across multiple related tasks. This module
//! provides implementations of state-of-the-art meta-learning algorithms:
//!
//! - **MAML** (Model-Agnostic Meta-Learning): Learns an initialization that can
//!   quickly adapt via a few gradient steps
//! - **Reptile**: Simpler first-order approximation that directly moves toward
//!   task-specific parameters
//! - **Task sampling**: Infrastructure for episodic meta-learning
//!
//! # Key Concepts
//!
//! - **Meta-training**: Outer loop that updates the meta-parameters
//! - **Task adaptation**: Inner loop that adapts to specific tasks
//! - **Support set**: Training data for task adaptation
//! - **Query set**: Validation data for meta-objective
//!
//! # Examples
//!
//! ```rust
//! use tensorlogic_train::{MAMLConfig, ReptileConfig, MetaLearner};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Configure MAML
//! let maml_config = MAMLConfig {
//!     inner_steps: 5,
//!     inner_lr: 0.01,
//!     outer_lr: 0.001,
//!     first_order: false,
//! };
//!
//! // Configure Reptile
//! let reptile_config = ReptileConfig {
//!     inner_steps: 5,
//!     inner_lr: 0.01,
//!     outer_lr: 0.001,
//! };
//! # Ok(())
//! # }
//! ```

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
52
/// Configuration for MAML (Model-Agnostic Meta-Learning).
///
/// MAML searches for an initialization θ* from which a handful of
/// gradient steps suffice to perform well on a previously unseen task.
#[derive(Debug, Clone)]
pub struct MAMLConfig {
    /// Inner-loop step count: how many gradient updates each task adaptation runs.
    pub inner_steps: usize,
    /// Inner-loop (task adaptation) learning rate.
    pub inner_lr: f64,
    /// Outer-loop (meta-update) learning rate.
    pub outer_lr: f64,
    /// When true, drop second-derivative terms (first-order MAML / FOMAML).
    pub first_order: bool,
}
68
69impl Default for MAMLConfig {
70    fn default() -> Self {
71        Self {
72            inner_steps: 5,
73            inner_lr: 0.01,
74            outer_lr: 0.001,
75            first_order: false,
76        }
77    }
78}
79
/// Configuration for the Reptile algorithm.
///
/// Reptile is a first-order alternative to MAML: for each sampled task it
/// trains to task-specific parameters, then nudges the meta-parameters
/// toward them by an interpolation step.
#[derive(Debug, Clone)]
pub struct ReptileConfig {
    /// How many SGD steps to run during task adaptation.
    pub inner_steps: usize,
    /// Learning rate used inside task adaptation.
    pub inner_lr: f64,
    /// Interpolation weight ε for the meta-update θ ← θ + ε(φ − θ).
    pub outer_lr: f64,
}
95
96impl Default for ReptileConfig {
97    fn default() -> Self {
98        Self {
99            inner_steps: 10,
100            inner_lr: 0.01,
101            outer_lr: 0.1,
102        }
103    }
104}
105
106/// Meta-learning task representation.
107///
108/// Each task consists of a support set (for adaptation) and
109/// a query set (for evaluation).
110#[derive(Debug, Clone)]
111pub struct MetaTask {
112    /// Support set features (for training/adaptation).
113    pub support_x: Array2<f64>,
114    /// Support set labels.
115    pub support_y: Array2<f64>,
116    /// Query set features (for evaluation).
117    pub query_x: Array2<f64>,
118    /// Query set labels.
119    pub query_y: Array2<f64>,
120}
121
122impl MetaTask {
123    /// Create a new meta-learning task.
124    pub fn new(
125        support_x: Array2<f64>,
126        support_y: Array2<f64>,
127        query_x: Array2<f64>,
128        query_y: Array2<f64>,
129    ) -> TrainResult<Self> {
130        if support_x.nrows() != support_y.nrows() {
131            return Err(TrainError::InvalidParameter(format!(
132                "Support X rows ({}) must match support Y rows ({})",
133                support_x.nrows(),
134                support_y.nrows()
135            )));
136        }
137
138        if query_x.nrows() != query_y.nrows() {
139            return Err(TrainError::InvalidParameter(format!(
140                "Query X rows ({}) must match query Y rows ({})",
141                query_x.nrows(),
142                query_y.nrows()
143            )));
144        }
145
146        Ok(Self {
147            support_x,
148            support_y,
149            query_x,
150            query_y,
151        })
152    }
153
154    /// Get support set size.
155    pub fn support_size(&self) -> usize {
156        self.support_x.nrows()
157    }
158
159    /// Get query set size.
160    pub fn query_size(&self) -> usize {
161        self.query_x.nrows()
162    }
163}
164
165/// Meta-learner trait for different meta-learning algorithms.
166pub trait MetaLearner {
167    /// Perform one step of meta-training on a batch of tasks.
168    ///
169    /// # Arguments
170    /// * `tasks` - Batch of meta-learning tasks
171    /// * `parameters` - Current meta-parameters
172    ///
173    /// # Returns
174    /// Updated meta-parameters and meta-loss
175    fn meta_step(
176        &self,
177        tasks: &[MetaTask],
178        parameters: &HashMap<String, Array1<f64>>,
179    ) -> TrainResult<(HashMap<String, Array1<f64>>, f64)>;
180
181    /// Adapt parameters to a specific task (inner loop).
182    ///
183    /// # Arguments
184    /// * `task` - Task to adapt to
185    /// * `parameters` - Initial parameters
186    ///
187    /// # Returns
188    /// Task-adapted parameters
189    fn adapt(
190        &self,
191        task: &MetaTask,
192        parameters: &HashMap<String, Array1<f64>>,
193    ) -> TrainResult<HashMap<String, Array1<f64>>>;
194}
195
196/// MAML (Model-Agnostic Meta-Learning) implementation.
197///
198/// MAML optimizes for a model initialization that can quickly adapt
199/// to new tasks through a few gradient steps.
200#[derive(Debug, Clone)]
201pub struct MAML {
202    config: MAMLConfig,
203}
204
205impl MAML {
206    /// Create a new MAML meta-learner.
207    pub fn new(config: MAMLConfig) -> Self {
208        Self { config }
209    }
210}
211
212impl Default for MAML {
213    fn default() -> Self {
214        Self::new(MAMLConfig::default())
215    }
216}
217
218impl MetaLearner for MAML {
219    fn meta_step(
220        &self,
221        tasks: &[MetaTask],
222        parameters: &HashMap<String, Array1<f64>>,
223    ) -> TrainResult<(HashMap<String, Array1<f64>>, f64)> {
224        let mut meta_gradients: HashMap<String, Array1<f64>> = HashMap::new();
225        let mut total_loss = 0.0;
226
227        // Initialize meta-gradients to zero
228        for (name, param) in parameters {
229            meta_gradients.insert(name.clone(), Array1::zeros(param.len()));
230        }
231
232        // For each task in the batch
233        for task in tasks {
234            // 1. Adapt to the task (inner loop)
235            let adapted_params = self.adapt(task, parameters)?;
236
237            // 2. Compute loss on query set with adapted parameters
238            // This is a placeholder - in practice, you'd use your model here
239            let query_loss = self.compute_query_loss(task, &adapted_params)?;
240            total_loss += query_loss;
241
242            // 3. Compute gradients of query loss w.r.t. meta-parameters
243            // In full MAML, we'd backprop through the adaptation process
244            // For first-order MAML, we use adapted params directly
245            let task_gradients = if self.config.first_order {
246                self.compute_first_order_gradients(task, &adapted_params)?
247            } else {
248                self.compute_second_order_gradients(task, parameters, &adapted_params)?
249            };
250
251            // 4. Accumulate gradients
252            for (name, grad) in task_gradients {
253                if let Some(meta_grad) = meta_gradients.get_mut(&name) {
254                    *meta_grad = meta_grad.clone() + grad;
255                }
256            }
257        }
258
259        // Average gradients and loss
260        let n_tasks = tasks.len() as f64;
261        for grad in meta_gradients.values_mut() {
262            *grad = grad.mapv(|x| x / n_tasks);
263        }
264        total_loss /= n_tasks;
265
266        // Meta-update (SGD step)
267        let mut updated_params = HashMap::new();
268        for (name, param) in parameters {
269            if let Some(grad) = meta_gradients.get(name) {
270                let updated = param - &grad.mapv(|g| g * self.config.outer_lr);
271                updated_params.insert(name.clone(), updated);
272            }
273        }
274
275        Ok((updated_params, total_loss))
276    }
277
278    fn adapt(
279        &self,
280        task: &MetaTask,
281        parameters: &HashMap<String, Array1<f64>>,
282    ) -> TrainResult<HashMap<String, Array1<f64>>> {
283        let mut adapted_params = parameters.clone();
284
285        // Perform inner loop updates
286        for _ in 0..self.config.inner_steps {
287            // Compute loss and gradients on support set
288            let gradients = self.compute_support_gradients(task, &adapted_params)?;
289
290            // SGD update
291            for (name, param) in &mut adapted_params {
292                if let Some(grad) = gradients.get(name) {
293                    *param = param.clone() - &grad.mapv(|g| g * self.config.inner_lr);
294                }
295            }
296        }
297
298        Ok(adapted_params)
299    }
300}
301
302impl MAML {
303    /// Compute gradients on support set (placeholder).
304    fn compute_support_gradients(
305        &self,
306        task: &MetaTask,
307        _parameters: &HashMap<String, Array1<f64>>,
308    ) -> TrainResult<HashMap<String, Array1<f64>>> {
309        // This is a simplified placeholder
310        // In practice, you'd compute actual gradients through your model
311        let mut gradients = HashMap::new();
312        gradients.insert("weights".to_string(), Array1::zeros(task.support_x.ncols()));
313        Ok(gradients)
314    }
315
316    /// Compute loss on query set (placeholder).
317    fn compute_query_loss(
318        &self,
319        task: &MetaTask,
320        _parameters: &HashMap<String, Array1<f64>>,
321    ) -> TrainResult<f64> {
322        // This is a simplified placeholder
323        // In practice, you'd compute actual loss through your model
324        Ok(task.query_size() as f64 * 0.1)
325    }
326
327    /// Compute first-order gradients (placeholder).
328    fn compute_first_order_gradients(
329        &self,
330        task: &MetaTask,
331        _parameters: &HashMap<String, Array1<f64>>,
332    ) -> TrainResult<HashMap<String, Array1<f64>>> {
333        // This is a simplified placeholder
334        let mut gradients = HashMap::new();
335        gradients.insert("weights".to_string(), Array1::zeros(task.query_x.ncols()));
336        Ok(gradients)
337    }
338
339    /// Compute second-order gradients through adaptation (placeholder).
340    fn compute_second_order_gradients(
341        &self,
342        task: &MetaTask,
343        _meta_params: &HashMap<String, Array1<f64>>,
344        _adapted_params: &HashMap<String, Array1<f64>>,
345    ) -> TrainResult<HashMap<String, Array1<f64>>> {
346        // This is a simplified placeholder
347        // In full MAML, we'd backprop through the inner loop
348        let mut gradients = HashMap::new();
349        gradients.insert("weights".to_string(), Array1::zeros(task.query_x.ncols()));
350        Ok(gradients)
351    }
352}
353
354/// Reptile meta-learning algorithm.
355///
356/// Reptile is a simpler first-order algorithm that:
357/// 1. Samples a task
358/// 2. Trains on it via SGD to get φ
359/// 3. Updates θ ← θ + ε(φ - θ)
360#[derive(Debug, Clone)]
361pub struct Reptile {
362    config: ReptileConfig,
363}
364
365impl Reptile {
366    /// Create a new Reptile meta-learner.
367    pub fn new(config: ReptileConfig) -> Self {
368        Self { config }
369    }
370}
371
372impl Default for Reptile {
373    fn default() -> Self {
374        Self::new(ReptileConfig::default())
375    }
376}
377
378impl MetaLearner for Reptile {
379    fn meta_step(
380        &self,
381        tasks: &[MetaTask],
382        parameters: &HashMap<String, Array1<f64>>,
383    ) -> TrainResult<(HashMap<String, Array1<f64>>, f64)> {
384        let mut total_loss = 0.0;
385        let mut accumulated_delta: HashMap<String, Array1<f64>> = HashMap::new();
386
387        // Initialize accumulated delta to zero
388        for (name, param) in parameters {
389            accumulated_delta.insert(name.clone(), Array1::zeros(param.len()));
390        }
391
392        // For each task in the batch
393        for task in tasks {
394            // 1. Adapt to the task (train on support set)
395            let task_params = self.adapt(task, parameters)?;
396
397            // 2. Compute task loss (for monitoring)
398            let task_loss = self.compute_task_loss(task, &task_params)?;
399            total_loss += task_loss;
400
401            // 3. Compute direction: φ - θ
402            for (name, param) in parameters {
403                if let Some(task_param) = task_params.get(name) {
404                    let delta = task_param - param;
405                    if let Some(acc_delta) = accumulated_delta.get_mut(name) {
406                        *acc_delta = acc_delta.clone() + delta;
407                    }
408                }
409            }
410        }
411
412        // Average delta and loss
413        let n_tasks = tasks.len() as f64;
414        for delta in accumulated_delta.values_mut() {
415            *delta = delta.mapv(|x| x / n_tasks);
416        }
417        total_loss /= n_tasks;
418
419        // Meta-update: θ ← θ + ε * average_delta
420        let mut updated_params = HashMap::new();
421        for (name, param) in parameters {
422            if let Some(delta) = accumulated_delta.get(name) {
423                let updated = param + &delta.mapv(|d| d * self.config.outer_lr);
424                updated_params.insert(name.clone(), updated);
425            }
426        }
427
428        Ok((updated_params, total_loss))
429    }
430
431    fn adapt(
432        &self,
433        task: &MetaTask,
434        parameters: &HashMap<String, Array1<f64>>,
435    ) -> TrainResult<HashMap<String, Array1<f64>>> {
436        let mut task_params = parameters.clone();
437
438        // Perform SGD steps on support set
439        for _ in 0..self.config.inner_steps {
440            // Compute loss and gradients on support set
441            let gradients = self.compute_support_gradients(task, &task_params)?;
442
443            // SGD update
444            for (name, param) in &mut task_params {
445                if let Some(grad) = gradients.get(name) {
446                    *param = param.clone() - &grad.mapv(|g| g * self.config.inner_lr);
447                }
448            }
449        }
450
451        Ok(task_params)
452    }
453}
454
455impl Reptile {
456    /// Compute gradients on support set (placeholder).
457    fn compute_support_gradients(
458        &self,
459        task: &MetaTask,
460        _parameters: &HashMap<String, Array1<f64>>,
461    ) -> TrainResult<HashMap<String, Array1<f64>>> {
462        // This is a simplified placeholder
463        let mut gradients = HashMap::new();
464        gradients.insert("weights".to_string(), Array1::zeros(task.support_x.ncols()));
465        Ok(gradients)
466    }
467
468    /// Compute task loss (placeholder).
469    fn compute_task_loss(
470        &self,
471        task: &MetaTask,
472        _parameters: &HashMap<String, Array1<f64>>,
473    ) -> TrainResult<f64> {
474        // This is a simplified placeholder
475        Ok(task.query_size() as f64 * 0.1)
476    }
477}
478
/// Bookkeeping for meta-training progress.
#[derive(Debug, Clone, Default)]
pub struct MetaStats {
    /// Meta-loss recorded at each meta-step, in order.
    pub meta_losses: Vec<f64>,
    /// Per-task adaptation loss curves, indexed by task id.
    pub task_losses: Vec<Vec<f64>>,
    /// Count of meta-steps recorded so far.
    pub iterations: usize,
}
489
490impl MetaStats {
491    /// Create a new statistics tracker.
492    pub fn new() -> Self {
493        Self::default()
494    }
495
496    /// Record a meta-training step.
497    pub fn record_meta_step(&mut self, meta_loss: f64) {
498        self.meta_losses.push(meta_loss);
499        self.iterations += 1;
500    }
501
502    /// Record task adaptation.
503    pub fn record_task_adaptation(&mut self, task_id: usize, losses: Vec<f64>) {
504        while self.task_losses.len() <= task_id {
505            self.task_losses.push(Vec::new());
506        }
507        self.task_losses[task_id] = losses;
508    }
509
510    /// Get average meta-loss over last N steps.
511    pub fn avg_meta_loss(&self, last_n: usize) -> f64 {
512        if self.meta_losses.is_empty() {
513            return 0.0;
514        }
515
516        let n = last_n.min(self.meta_losses.len());
517        let start = self.meta_losses.len() - n;
518        self.meta_losses[start..].iter().sum::<f64>() / n as f64
519    }
520
521    /// Check if meta-training is improving (loss decreasing).
522    pub fn is_improving(&self, window: usize) -> bool {
523        if self.meta_losses.len() < window * 2 {
524            return false;
525        }
526
527        let recent = self.avg_meta_loss(window);
528        let previous = {
529            let start = self.meta_losses.len() - window * 2;
530            let end = self.meta_losses.len() - window;
531            self.meta_losses[start..end].iter().sum::<f64>() / window as f64
532        };
533
534        recent < previous
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a small all-zero task: 5 support rows, 15 query rows,
    /// 10 features, 2 outputs.
    fn create_dummy_task() -> MetaTask {
        MetaTask::new(
            Array2::zeros((5, 10)),
            Array2::zeros((5, 2)),
            Array2::zeros((15, 10)),
            Array2::zeros((15, 2)),
        )
        .unwrap()
    }

    #[test]
    fn test_maml_config_default() {
        let MAMLConfig {
            inner_steps,
            inner_lr,
            outer_lr,
            first_order,
        } = MAMLConfig::default();
        assert_eq!(inner_steps, 5);
        assert_eq!(inner_lr, 0.01);
        assert_eq!(outer_lr, 0.001);
        assert!(!first_order);
    }

    #[test]
    fn test_reptile_config_default() {
        let config = ReptileConfig::default();
        assert_eq!(
            (config.inner_steps, config.inner_lr, config.outer_lr),
            (10, 0.01, 0.1)
        );
    }

    #[test]
    fn test_meta_task_creation() {
        let task = MetaTask::new(
            Array2::zeros((5, 10)),
            Array2::zeros((5, 2)),
            Array2::zeros((15, 10)),
            Array2::zeros((15, 2)),
        )
        .unwrap();
        assert_eq!(task.support_size(), 5);
        assert_eq!(task.query_size(), 15);
    }

    #[test]
    fn test_meta_task_validation() {
        // Support labels have 4 rows but features have 5 — must be rejected.
        let result = MetaTask::new(
            Array2::zeros((5, 10)),
            Array2::zeros((4, 2)),
            Array2::zeros((15, 10)),
            Array2::zeros((15, 2)),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_maml_creation() {
        let maml = MAML::new(MAMLConfig::default());
        assert_eq!(maml.config.inner_steps, 5);
    }

    #[test]
    fn test_maml_default() {
        assert_eq!(MAML::default().config.inner_steps, 5);
    }

    #[test]
    fn test_reptile_creation() {
        let reptile = Reptile::new(ReptileConfig::default());
        assert_eq!(reptile.config.inner_steps, 10);
    }

    #[test]
    fn test_reptile_default() {
        assert_eq!(Reptile::default().config.inner_steps, 10);
    }

    #[test]
    fn test_maml_adapt() {
        let maml = MAML::default();
        let task = create_dummy_task();
        let params: HashMap<_, _> =
            [("weights".to_string(), Array1::zeros(10))].into_iter().collect();

        let adapted = maml.adapt(&task, &params).unwrap();
        assert!(adapted.contains_key("weights"));
    }

    #[test]
    fn test_reptile_adapt() {
        let reptile = Reptile::default();
        let task = create_dummy_task();
        let params: HashMap<_, _> =
            [("weights".to_string(), Array1::zeros(10))].into_iter().collect();

        let adapted = reptile.adapt(&task, &params).unwrap();
        assert!(adapted.contains_key("weights"));
    }

    #[test]
    fn test_maml_meta_step() {
        let maml = MAML::default();
        let tasks = vec![create_dummy_task(), create_dummy_task()];
        let params: HashMap<_, _> =
            [("weights".to_string(), Array1::zeros(10))].into_iter().collect();

        let (updated_params, loss) = maml.meta_step(&tasks, &params).unwrap();
        assert!(updated_params.contains_key("weights"));
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_reptile_meta_step() {
        let reptile = Reptile::default();
        let tasks = vec![create_dummy_task(), create_dummy_task()];
        let params: HashMap<_, _> =
            [("weights".to_string(), Array1::zeros(10))].into_iter().collect();

        let (updated_params, loss) = reptile.meta_step(&tasks, &params).unwrap();
        assert!(updated_params.contains_key("weights"));
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_meta_stats() {
        let mut stats = MetaStats::new();
        for loss in [1.0, 0.8, 0.6] {
            stats.record_meta_step(loss);
        }

        assert_eq!(stats.iterations, 3);
        assert_eq!(stats.meta_losses.len(), 3);
        assert_eq!(stats.avg_meta_loss(2), 0.7);
    }

    #[test]
    fn test_meta_stats_improvement() {
        let mut stats = MetaStats::new();
        // Strictly decreasing losses should register as improvement.
        for i in 0..20 {
            stats.record_meta_step(1.0 - 0.01 * i as f64);
        }
        assert!(stats.is_improving(5));
    }

    #[test]
    fn test_meta_stats_task_adaptation() {
        let mut stats = MetaStats::new();
        stats.record_task_adaptation(0, vec![1.0, 0.8, 0.6]);
        stats.record_task_adaptation(1, vec![1.2, 0.9, 0.7]);

        assert_eq!(stats.task_losses.len(), 2);
        assert_eq!(stats.task_losses[0].len(), 3);
    }
}