// sklears_multioutput/regularization/meta_learning.rs

1//! Meta-Learning for Multi-Task Learning
2//!
3//! This method learns meta-parameters that can quickly adapt to new tasks.
4//! It uses a model-agnostic meta-learning (MAML) approach adapted for multi-task scenarios.
5
6// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
8use scirs2_core::random::thread_rng;
9use scirs2_core::random::RandNormal;
10use sklears_core::{
11    error::{Result as SklResult, SklearsError},
12    traits::{Estimator, Fit, Predict, Untrained},
13    types::Float,
14};
15use std::collections::HashMap;
16
17/// Meta-Learning for Multi-Task Learning
18///
19/// This method learns meta-parameters that can quickly adapt to new tasks.
20/// It uses a model-agnostic meta-learning (MAML) approach adapted for multi-task scenarios.
21///
22/// # Examples
23///
24/// ```
25/// use sklears_multioutput::regularization::MetaLearningMultiTask;
26/// use sklears_core::traits::{Predict, Fit};
27/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
28/// use scirs2_core::ndarray::array;
29/// use std::collections::HashMap;
30///
31/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
32/// let mut y_tasks = HashMap::new();
33/// y_tasks.insert("task1".to_string(), array![[1.0], [2.0], [1.5], [2.5]]);
34/// y_tasks.insert("task2".to_string(), array![[0.5], [1.0], [0.8], [1.2]]);
35///
36/// let meta_learning = MetaLearningMultiTask::new()
37///     .meta_learning_rate(0.01)
38///     .inner_learning_rate(0.1)
39///     .n_inner_steps(5)
40///     .max_iter(1000);
41/// ```
#[derive(Debug, Clone)]
pub struct MetaLearningMultiTask<S = Untrained> {
    /// Type-state marker: `Untrained` before `fit`, trained state after
    pub(crate) state: S,
    /// Meta-learning rate for updating meta-parameters (outer-loop step size)
    pub(crate) meta_learning_rate: Float,
    /// Inner learning rate for task-specific adaptation
    pub(crate) inner_learning_rate: Float,
    /// Number of inner gradient steps per task
    pub(crate) n_inner_steps: usize,
    /// Maximum meta-iterations
    pub(crate) max_iter: usize,
    /// Convergence tolerance on the change in total meta-loss
    pub(crate) tolerance: Float,
    /// Task configurations (task name -> number of outputs)
    pub(crate) task_outputs: HashMap<String, usize>,
    /// Include intercept term
    pub(crate) fit_intercept: bool,
    /// Random state for reproducible meta-learning.
    /// NOTE(review): `fit` currently draws from `thread_rng` and never reads
    /// this seed — confirm before relying on reproducibility.
    pub(crate) random_state: Option<u64>,
}
62
/// Trained state for MetaLearningMultiTask
///
/// Holds the learned meta-initialization plus the per-task parameters that
/// were adapted from it during `fit`.
#[derive(Debug, Clone)]
pub struct MetaLearningMultiTaskTrained {
    /// Meta-parameters (initialization for new tasks), shape (n_features, n_outputs)
    pub(crate) meta_parameters: Array2<Float>,
    /// Meta-intercepts, one per output column
    pub(crate) meta_intercepts: Array1<Float>,
    /// Task-specific adapted parameters, keyed by task name
    pub(crate) task_parameters: HashMap<String, Array2<Float>>,
    /// Task-specific adapted intercepts, keyed by task name
    pub(crate) task_intercepts: HashMap<String, Array1<Float>>,
    /// Number of input features seen at fit time (re-checked in `predict`)
    pub(crate) n_features: usize,
    /// Task configurations copied from the untrained estimator
    pub(crate) task_outputs: HashMap<String, usize>,
    /// Training parameters, retained so `adapt_to_new_task` can reuse them
    pub(crate) meta_learning_rate: Float,
    pub(crate) inner_learning_rate: Float,
    pub(crate) n_inner_steps: usize,
    /// Training iterations performed
    pub(crate) n_iter: usize,
}
85
86impl MetaLearningMultiTask<Untrained> {
87    /// Create a new MetaLearningMultiTask instance
88    pub fn new() -> Self {
89        Self {
90            state: Untrained,
91            meta_learning_rate: 0.01,
92            inner_learning_rate: 0.1,
93            n_inner_steps: 5,
94            max_iter: 1000,
95            tolerance: 1e-4,
96            task_outputs: HashMap::new(),
97            fit_intercept: true,
98            random_state: None,
99        }
100    }
101
102    /// Set meta-learning rate
103    pub fn meta_learning_rate(mut self, lr: Float) -> Self {
104        self.meta_learning_rate = lr;
105        self
106    }
107
108    /// Set inner learning rate
109    pub fn inner_learning_rate(mut self, lr: Float) -> Self {
110        self.inner_learning_rate = lr;
111        self
112    }
113
114    /// Set number of inner gradient steps
115    pub fn n_inner_steps(mut self, steps: usize) -> Self {
116        self.n_inner_steps = steps;
117        self
118    }
119
120    /// Set maximum iterations
121    pub fn max_iter(mut self, max_iter: usize) -> Self {
122        self.max_iter = max_iter;
123        self
124    }
125
126    /// Set tolerance
127    pub fn tolerance(mut self, tolerance: Float) -> Self {
128        self.tolerance = tolerance;
129        self
130    }
131
132    /// Set random state
133    pub fn random_state(mut self, seed: u64) -> Self {
134        self.random_state = Some(seed);
135        self
136    }
137
138    /// Set task outputs
139    pub fn task_outputs(mut self, outputs: &[(&str, usize)]) -> Self {
140        self.task_outputs = outputs
141            .iter()
142            .map(|(name, size)| (name.to_string(), *size))
143            .collect();
144        self
145    }
146}
147
148impl Default for MetaLearningMultiTask<Untrained> {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
impl Estimator for MetaLearningMultiTask<Untrained> {
    // All hyper-parameters live directly on the struct, so the associated
    // configuration type is the unit type.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns the (empty) configuration object.
    fn config(&self) -> &Self::Config {
        // `&()` promotes to a `'static` reference to the unit value.
        &()
    }
}
163
164impl Fit<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
165    for MetaLearningMultiTask<Untrained>
166{
167    type Fitted = MetaLearningMultiTask<MetaLearningMultiTaskTrained>;
168
169    fn fit(
170        self,
171        X: &ArrayView2<'_, Float>,
172        y: &HashMap<String, Array2<Float>>,
173    ) -> SklResult<Self::Fitted> {
174        let x = X.to_owned();
175        let (n_samples, n_features) = x.dim();
176
177        if n_samples == 0 || n_features == 0 {
178            return Err(SklearsError::InvalidInput("Empty input data".to_string()));
179        }
180
181        // Initialize meta-parameters
182        let mut rng_gen = thread_rng();
183
184        // Use first task to determine output size for meta-parameters
185        let first_task_outputs = y.values().next().unwrap().ncols();
186        let mut meta_parameters = Array2::<Float>::zeros((n_features, first_task_outputs));
187        let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
188        for i in 0..n_features {
189            for j in 0..first_task_outputs {
190                meta_parameters[[i, j]] = rng_gen.sample(normal_dist);
191            }
192        }
193        let mut meta_intercepts = Array1::<Float>::zeros(first_task_outputs);
194
195        let task_names: Vec<String> = y.keys().cloned().collect();
196        let mut task_parameters: HashMap<String, Array2<Float>> = HashMap::new();
197        let mut task_intercepts: HashMap<String, Array1<Float>> = HashMap::new();
198
199        // Meta-learning loop
200        let mut prev_loss = Float::INFINITY;
201        let mut n_iter = 0;
202
203        for iteration in 0..self.max_iter {
204            let mut total_meta_loss = 0.0;
205            let mut meta_grad_sum: Array2<Float> = Array2::<Float>::zeros(meta_parameters.dim());
206            let mut meta_intercept_grad_sum: Array1<Float> =
207                Array1::<Float>::zeros(meta_intercepts.len());
208
209            // For each task, perform inner loop adaptation
210            for (task_name, y_task) in y {
211                // Initialize task parameters from meta-parameters
212                let mut task_params = meta_parameters.clone();
213                let mut task_intercept = meta_intercepts.clone();
214
215                // Inner loop: adapt to specific task
216                for _inner_step in 0..self.n_inner_steps {
217                    // Compute predictions
218                    let predictions = x.dot(&task_params);
219                    let predictions_with_intercept = &predictions + &task_intercept;
220
221                    // Compute residuals
222                    let residuals = &predictions_with_intercept - y_task;
223
224                    // Compute gradients
225                    let grad_params = x.t().dot(&residuals) / (n_samples as Float);
226                    let grad_intercept = residuals.sum_axis(Axis(0)) / (n_samples as Float);
227
228                    // Update task-specific parameters
229                    task_params -= &(&grad_params * self.inner_learning_rate);
230                    task_intercept -= &(&grad_intercept * self.inner_learning_rate);
231                }
232
233                // Compute final loss for this task
234                let final_predictions = x.dot(&task_params);
235                let final_predictions_with_intercept = &final_predictions + &task_intercept;
236                let final_residuals = &final_predictions_with_intercept - y_task;
237                let task_loss = final_residuals.mapv(|x| x * x).sum();
238                total_meta_loss += task_loss;
239
240                // Compute meta-gradients (how changes in meta-parameters affect final loss)
241                let meta_grad_params = x.t().dot(&final_residuals) / (n_samples as Float);
242                let meta_grad_intercept = final_residuals.sum_axis(Axis(0)) / (n_samples as Float);
243
244                meta_grad_sum = meta_grad_sum + meta_grad_params;
245                meta_intercept_grad_sum = meta_intercept_grad_sum + meta_grad_intercept;
246
247                // Store adapted parameters
248                task_parameters.insert(task_name.clone(), task_params);
249                task_intercepts.insert(task_name.clone(), task_intercept);
250            }
251
252            // Update meta-parameters
253            let n_tasks = y.len() as Float;
254            meta_parameters -= &(&(meta_grad_sum / n_tasks) * self.meta_learning_rate);
255            meta_intercepts -= &(&(meta_intercept_grad_sum / n_tasks) * self.meta_learning_rate);
256
257            // Check convergence
258            if (prev_loss - total_meta_loss).abs() < self.tolerance {
259                n_iter = iteration + 1;
260                break;
261            }
262            prev_loss = total_meta_loss;
263            n_iter = iteration + 1;
264        }
265
266        Ok(MetaLearningMultiTask {
267            state: MetaLearningMultiTaskTrained {
268                meta_parameters,
269                meta_intercepts,
270                task_parameters,
271                task_intercepts,
272                n_features,
273                task_outputs: self.task_outputs.clone(),
274                meta_learning_rate: self.meta_learning_rate,
275                inner_learning_rate: self.inner_learning_rate,
276                n_inner_steps: self.n_inner_steps,
277                n_iter,
278            },
279            meta_learning_rate: self.meta_learning_rate,
280            inner_learning_rate: self.inner_learning_rate,
281            n_inner_steps: self.n_inner_steps,
282            max_iter: self.max_iter,
283            tolerance: self.tolerance,
284            task_outputs: self.task_outputs,
285            fit_intercept: self.fit_intercept,
286            random_state: self.random_state,
287        })
288    }
289}
290
291impl Predict<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
292    for MetaLearningMultiTask<MetaLearningMultiTaskTrained>
293{
294    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<HashMap<String, Array2<Float>>> {
295        let x = X.to_owned();
296        let (n_samples, n_features) = x.dim();
297
298        if n_features != self.state.n_features {
299            return Err(SklearsError::InvalidInput(
300                "Number of features doesn't match training data".to_string(),
301            ));
302        }
303
304        let mut predictions = HashMap::new();
305
306        for (task_name, coef) in &self.state.task_parameters {
307            let task_predictions = x.dot(coef);
308            let intercept = &self.state.task_intercepts[task_name];
309            let final_predictions = &task_predictions + intercept;
310            predictions.insert(task_name.clone(), final_predictions);
311        }
312
313        Ok(predictions)
314    }
315}
316
317impl MetaLearningMultiTask<MetaLearningMultiTaskTrained> {
318    /// Adapt meta-parameters to a new task with few examples
319    pub fn adapt_to_new_task(
320        &self,
321        X: &ArrayView2<Float>,
322        y: &Array2<Float>,
323        n_adaptation_steps: usize,
324    ) -> SklResult<(Array2<Float>, Array1<Float>)> {
325        let x = X.to_owned();
326        let (n_samples, n_features) = x.dim();
327
328        if n_features != self.state.n_features {
329            return Err(SklearsError::InvalidInput(
330                "Number of features doesn't match training data".to_string(),
331            ));
332        }
333
334        // Start with meta-parameters
335        let mut adapted_params = self.state.meta_parameters.clone();
336        let mut adapted_intercept = self.state.meta_intercepts.clone();
337
338        // Perform adaptation steps
339        for _step in 0..n_adaptation_steps {
340            // Compute predictions
341            let predictions = x.dot(&adapted_params);
342            let predictions_with_intercept = &predictions + &adapted_intercept;
343
344            // Compute residuals
345            let residuals = &predictions_with_intercept - y;
346
347            // Compute gradients
348            let grad_params = x.t().dot(&residuals) / (n_samples as Float);
349            let grad_intercept = residuals.sum_axis(Axis(0)) / (n_samples as Float);
350
351            // Update parameters
352            adapted_params -= &(&grad_params * self.state.inner_learning_rate);
353            adapted_intercept -= &(&grad_intercept * self.state.inner_learning_rate);
354        }
355
356        Ok((adapted_params, adapted_intercept))
357    }
358
359    /// Get meta-parameters for initialization of new tasks
360    pub fn get_meta_parameters(&self) -> (&Array2<Float>, &Array1<Float>) {
361        (&self.state.meta_parameters, &self.state.meta_intercepts)
362    }
363}
364
impl MetaLearningMultiTaskTrained {
    /// Learned meta-parameter matrix (the initialization for new tasks).
    pub fn meta_parameters(&self) -> &Array2<Float> {
        &self.meta_parameters
    }

    /// Learned meta-intercept vector.
    pub fn meta_intercepts(&self) -> &Array1<Float> {
        &self.meta_intercepts
    }

    /// Adapted coefficient matrix for `task_name`, if that task was trained.
    pub fn task_parameters(&self, task_name: &str) -> Option<&Array2<Float>> {
        self.task_parameters.get(task_name)
    }

    /// Adapted intercepts for `task_name`, if that task was trained.
    pub fn task_intercepts(&self, task_name: &str) -> Option<&Array1<Float>> {
        self.task_intercepts.get(task_name)
    }

    /// Number of meta-iterations performed before convergence or `max_iter`.
    pub fn n_iter(&self) -> usize {
        self.n_iter
    }

    /// Meta-learning configuration as
    /// `(meta_learning_rate, inner_learning_rate, n_inner_steps)`.
    pub fn meta_learning_config(&self) -> (Float, Float, usize) {
        (
            self.meta_learning_rate,
            self.inner_learning_rate,
            self.n_inner_steps,
        )
    }
}