// sklears_multioutput/regularization/meta_learning.rs

use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;
/// Meta-learning multi-task linear regressor.
///
/// Trains shared "meta" coefficients with an inner/outer gradient loop:
/// each outer iteration adapts a copy of the meta-parameters to every task
/// (inner loop), then updates the meta-parameters from the post-adaptation
/// gradients (first-order, MAML/FOMAML-style update).
///
/// Uses the typestate pattern: `S = Untrained` before fitting,
/// `MetaLearningMultiTaskTrained` after a successful [`Fit::fit`].
#[derive(Debug, Clone)]
pub struct MetaLearningMultiTask<S = Untrained> {
    /// Typestate marker (`Untrained` or `MetaLearningMultiTaskTrained`).
    pub(crate) state: S,
    /// Step size for the outer (meta-parameter) gradient update.
    pub(crate) meta_learning_rate: Float,
    /// Step size for the per-task inner adaptation steps.
    pub(crate) inner_learning_rate: Float,
    /// Number of inner gradient steps per task per outer iteration.
    pub(crate) n_inner_steps: usize,
    /// Maximum number of outer (meta) iterations.
    pub(crate) max_iter: usize,
    /// Convergence threshold on the change in total meta-loss between
    /// consecutive outer iterations.
    pub(crate) tolerance: Float,
    /// Declared output width per task name.
    /// NOTE(review): stored and copied into the trained state, but `fit`
    /// derives output widths from `y` and never checks this map — confirm intent.
    pub(crate) task_outputs: HashMap<String, usize>,
    /// Whether to fit intercept terms.
    /// NOTE(review): `fit` always fits intercepts regardless of this flag — confirm.
    pub(crate) fit_intercept: bool,
    /// Optional RNG seed for reproducibility.
    /// NOTE(review): `fit` currently initializes from `thread_rng()` and
    /// ignores this seed — confirm.
    pub(crate) random_state: Option<u64>,
}
62
/// Trained state for [`MetaLearningMultiTask`], produced by [`Fit::fit`].
#[derive(Debug, Clone)]
pub struct MetaLearningMultiTaskTrained {
    /// Shared meta coefficients, shape `(n_features, n_outputs)`.
    pub(crate) meta_parameters: Array2<Float>,
    /// Shared meta intercepts, length `n_outputs`.
    pub(crate) meta_intercepts: Array1<Float>,
    /// Task-adapted coefficients keyed by task name.
    pub(crate) task_parameters: HashMap<String, Array2<Float>>,
    /// Task-adapted intercepts keyed by task name.
    pub(crate) task_intercepts: HashMap<String, Array1<Float>>,
    /// Number of features seen during fitting (used to validate inputs
    /// at predict/adapt time).
    pub(crate) n_features: usize,
    /// Task output-size map carried over from the untrained configuration.
    pub(crate) task_outputs: HashMap<String, usize>,
    /// Outer-loop learning rate used during training.
    pub(crate) meta_learning_rate: Float,
    /// Inner-loop learning rate used during training; also used by
    /// `adapt_to_new_task`.
    pub(crate) inner_learning_rate: Float,
    /// Number of inner steps used per task per outer iteration.
    pub(crate) n_inner_steps: usize,
    /// Number of outer iterations actually performed.
    pub(crate) n_iter: usize,
}
85
impl MetaLearningMultiTask<Untrained> {
    /// Creates an untrained estimator with default hyperparameters:
    /// meta lr `0.01`, inner lr `0.1`, 5 inner steps, 1000 outer
    /// iterations, tolerance `1e-4`, intercept fitting enabled.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            meta_learning_rate: 0.01,
            inner_learning_rate: 0.1,
            n_inner_steps: 5,
            max_iter: 1000,
            tolerance: 1e-4,
            task_outputs: HashMap::new(),
            fit_intercept: true,
            random_state: None,
        }
    }

    /// Sets the outer (meta-parameter) learning rate.
    pub fn meta_learning_rate(mut self, lr: Float) -> Self {
        self.meta_learning_rate = lr;
        self
    }

    /// Sets the inner (per-task adaptation) learning rate.
    pub fn inner_learning_rate(mut self, lr: Float) -> Self {
        self.inner_learning_rate = lr;
        self
    }

    /// Sets the number of inner adaptation steps per task per outer iteration.
    pub fn n_inner_steps(mut self, steps: usize) -> Self {
        self.n_inner_steps = steps;
        self
    }

    /// Sets the maximum number of outer (meta) iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the convergence tolerance on the change in total meta-loss.
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Sets the RNG seed.
    /// NOTE(review): the seed is stored but not yet consumed by `fit` — confirm.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Declares the expected output width for each named task.
    pub fn task_outputs(mut self, outputs: &[(&str, usize)]) -> Self {
        self.task_outputs = outputs
            .iter()
            .map(|(name, size)| (name.to_string(), *size))
            .collect();
        self
    }
}
147
impl Default for MetaLearningMultiTask<Untrained> {
    /// Equivalent to [`MetaLearningMultiTask::new`].
    fn default() -> Self {
        Self::new()
    }
}
153
impl Estimator for MetaLearningMultiTask<Untrained> {
    // This estimator keeps its configuration in builder fields, so the
    // trait-level config type is the unit type.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the (empty) unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
163
164impl Fit<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
165 for MetaLearningMultiTask<Untrained>
166{
167 type Fitted = MetaLearningMultiTask<MetaLearningMultiTaskTrained>;
168
169 fn fit(
170 self,
171 X: &ArrayView2<'_, Float>,
172 y: &HashMap<String, Array2<Float>>,
173 ) -> SklResult<Self::Fitted> {
174 let x = X.to_owned();
175 let (n_samples, n_features) = x.dim();
176
177 if n_samples == 0 || n_features == 0 {
178 return Err(SklearsError::InvalidInput("Empty input data".to_string()));
179 }
180
181 let mut rng_gen = thread_rng();
183
184 let first_task_outputs = y.values().next().unwrap().ncols();
186 let mut meta_parameters = Array2::<Float>::zeros((n_features, first_task_outputs));
187 let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
188 for i in 0..n_features {
189 for j in 0..first_task_outputs {
190 meta_parameters[[i, j]] = rng_gen.sample(normal_dist);
191 }
192 }
193 let mut meta_intercepts = Array1::<Float>::zeros(first_task_outputs);
194
195 let task_names: Vec<String> = y.keys().cloned().collect();
196 let mut task_parameters: HashMap<String, Array2<Float>> = HashMap::new();
197 let mut task_intercepts: HashMap<String, Array1<Float>> = HashMap::new();
198
199 let mut prev_loss = Float::INFINITY;
201 let mut n_iter = 0;
202
203 for iteration in 0..self.max_iter {
204 let mut total_meta_loss = 0.0;
205 let mut meta_grad_sum: Array2<Float> = Array2::<Float>::zeros(meta_parameters.dim());
206 let mut meta_intercept_grad_sum: Array1<Float> =
207 Array1::<Float>::zeros(meta_intercepts.len());
208
209 for (task_name, y_task) in y {
211 let mut task_params = meta_parameters.clone();
213 let mut task_intercept = meta_intercepts.clone();
214
215 for _inner_step in 0..self.n_inner_steps {
217 let predictions = x.dot(&task_params);
219 let predictions_with_intercept = &predictions + &task_intercept;
220
221 let residuals = &predictions_with_intercept - y_task;
223
224 let grad_params = x.t().dot(&residuals) / (n_samples as Float);
226 let grad_intercept = residuals.sum_axis(Axis(0)) / (n_samples as Float);
227
228 task_params -= &(&grad_params * self.inner_learning_rate);
230 task_intercept -= &(&grad_intercept * self.inner_learning_rate);
231 }
232
233 let final_predictions = x.dot(&task_params);
235 let final_predictions_with_intercept = &final_predictions + &task_intercept;
236 let final_residuals = &final_predictions_with_intercept - y_task;
237 let task_loss = final_residuals.mapv(|x| x * x).sum();
238 total_meta_loss += task_loss;
239
240 let meta_grad_params = x.t().dot(&final_residuals) / (n_samples as Float);
242 let meta_grad_intercept = final_residuals.sum_axis(Axis(0)) / (n_samples as Float);
243
244 meta_grad_sum = meta_grad_sum + meta_grad_params;
245 meta_intercept_grad_sum = meta_intercept_grad_sum + meta_grad_intercept;
246
247 task_parameters.insert(task_name.clone(), task_params);
249 task_intercepts.insert(task_name.clone(), task_intercept);
250 }
251
252 let n_tasks = y.len() as Float;
254 meta_parameters -= &(&(meta_grad_sum / n_tasks) * self.meta_learning_rate);
255 meta_intercepts -= &(&(meta_intercept_grad_sum / n_tasks) * self.meta_learning_rate);
256
257 if (prev_loss - total_meta_loss).abs() < self.tolerance {
259 n_iter = iteration + 1;
260 break;
261 }
262 prev_loss = total_meta_loss;
263 n_iter = iteration + 1;
264 }
265
266 Ok(MetaLearningMultiTask {
267 state: MetaLearningMultiTaskTrained {
268 meta_parameters,
269 meta_intercepts,
270 task_parameters,
271 task_intercepts,
272 n_features,
273 task_outputs: self.task_outputs.clone(),
274 meta_learning_rate: self.meta_learning_rate,
275 inner_learning_rate: self.inner_learning_rate,
276 n_inner_steps: self.n_inner_steps,
277 n_iter,
278 },
279 meta_learning_rate: self.meta_learning_rate,
280 inner_learning_rate: self.inner_learning_rate,
281 n_inner_steps: self.n_inner_steps,
282 max_iter: self.max_iter,
283 tolerance: self.tolerance,
284 task_outputs: self.task_outputs,
285 fit_intercept: self.fit_intercept,
286 random_state: self.random_state,
287 })
288 }
289}
290
291impl Predict<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
292 for MetaLearningMultiTask<MetaLearningMultiTaskTrained>
293{
294 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<HashMap<String, Array2<Float>>> {
295 let x = X.to_owned();
296 let (n_samples, n_features) = x.dim();
297
298 if n_features != self.state.n_features {
299 return Err(SklearsError::InvalidInput(
300 "Number of features doesn't match training data".to_string(),
301 ));
302 }
303
304 let mut predictions = HashMap::new();
305
306 for (task_name, coef) in &self.state.task_parameters {
307 let task_predictions = x.dot(coef);
308 let intercept = &self.state.task_intercepts[task_name];
309 let final_predictions = &task_predictions + intercept;
310 predictions.insert(task_name.clone(), final_predictions);
311 }
312
313 Ok(predictions)
314 }
315}
316
317impl MetaLearningMultiTask<MetaLearningMultiTaskTrained> {
318 pub fn adapt_to_new_task(
320 &self,
321 X: &ArrayView2<Float>,
322 y: &Array2<Float>,
323 n_adaptation_steps: usize,
324 ) -> SklResult<(Array2<Float>, Array1<Float>)> {
325 let x = X.to_owned();
326 let (n_samples, n_features) = x.dim();
327
328 if n_features != self.state.n_features {
329 return Err(SklearsError::InvalidInput(
330 "Number of features doesn't match training data".to_string(),
331 ));
332 }
333
334 let mut adapted_params = self.state.meta_parameters.clone();
336 let mut adapted_intercept = self.state.meta_intercepts.clone();
337
338 for _step in 0..n_adaptation_steps {
340 let predictions = x.dot(&adapted_params);
342 let predictions_with_intercept = &predictions + &adapted_intercept;
343
344 let residuals = &predictions_with_intercept - y;
346
347 let grad_params = x.t().dot(&residuals) / (n_samples as Float);
349 let grad_intercept = residuals.sum_axis(Axis(0)) / (n_samples as Float);
350
351 adapted_params -= &(&grad_params * self.state.inner_learning_rate);
353 adapted_intercept -= &(&grad_intercept * self.state.inner_learning_rate);
354 }
355
356 Ok((adapted_params, adapted_intercept))
357 }
358
359 pub fn get_meta_parameters(&self) -> (&Array2<Float>, &Array1<Float>) {
361 (&self.state.meta_parameters, &self.state.meta_intercepts)
362 }
363}
364
impl MetaLearningMultiTaskTrained {
    /// Shared meta coefficients, shape `(n_features, n_outputs)`.
    pub fn meta_parameters(&self) -> &Array2<Float> {
        &self.meta_parameters
    }

    /// Shared meta intercepts, length `n_outputs`.
    pub fn meta_intercepts(&self) -> &Array1<Float> {
        &self.meta_intercepts
    }

    /// Task-adapted coefficients for `task_name`, if that task was fitted.
    pub fn task_parameters(&self, task_name: &str) -> Option<&Array2<Float>> {
        self.task_parameters.get(task_name)
    }

    /// Task-adapted intercepts for `task_name`, if that task was fitted.
    pub fn task_intercepts(&self, task_name: &str) -> Option<&Array1<Float>> {
        self.task_intercepts.get(task_name)
    }

    /// Number of outer iterations actually performed during fitting.
    pub fn n_iter(&self) -> usize {
        self.n_iter
    }

    /// Returns `(meta_learning_rate, inner_learning_rate, n_inner_steps)`.
    pub fn meta_learning_config(&self) -> (Float, Float, usize) {
        (
            self.meta_learning_rate,
            self.inner_learning_rate,
            self.n_inner_steps,
        )
    }
}