1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8 error::Result as SklResult,
9 prelude::{Predict, SklearsError},
10 traits::{Estimator, Fit, Untrained},
11 types::Float,
12};
13use std::collections::{HashMap, VecDeque};
14
15use crate::{PipelinePredictor, PipelineStep};
16
/// A record of a previously seen learning task, consulted by the meta-learning
/// pipeline when adapting its parameters to new data.
#[derive(Debug, Clone)]
pub struct Experience {
    /// Identifier of the task this experience came from.
    pub task_id: String,
    /// Feature matrix observed for the task. `ExperienceStorage` compares tasks
    /// by the cosine similarity of the column means of this matrix.
    pub features: Array2<f64>,
    /// Target values observed for the task.
    pub targets: Array1<f64>,
    /// Free-form string metadata; not read by `AdaptationStrategy`.
    pub metadata: HashMap<String, String>,
    /// Performance metrics by name. `AdaptationStrategy` reads the `"accuracy"`
    /// and `"loss"` entries.
    pub performance: HashMap<String, f64>,
    /// Named parameter values associated with this experience; the adaptation
    /// strategies use them as reference values for the current parameters.
    pub parameters: HashMap<String, f64>,
}
33
34impl Experience {
35 #[must_use]
37 pub fn new(task_id: String, features: Array2<f64>, targets: Array1<f64>) -> Self {
38 Self {
39 task_id,
40 features,
41 targets,
42 metadata: HashMap::new(),
43 performance: HashMap::new(),
44 parameters: HashMap::new(),
45 }
46 }
47
48 #[must_use]
50 pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
51 self.metadata = metadata;
52 self
53 }
54
55 #[must_use]
57 pub fn with_performance(mut self, performance: HashMap<String, f64>) -> Self {
58 self.performance = performance;
59 self
60 }
61
62 #[must_use]
64 pub fn with_parameters(mut self, parameters: HashMap<String, f64>) -> Self {
65 self.parameters = parameters;
66 self
67 }
68}
69
/// Bounded FIFO store of [`Experience`]s with a per-task lookup index.
#[derive(Debug, Clone)]
pub struct ExperienceStorage {
    /// Maximum number of experiences to keep; the oldest entry is evicted when
    /// a new one is added to a full storage.
    max_size: usize,
    /// Stored experiences, oldest at the front (eviction pops from the front).
    experiences: VecDeque<Experience>,
    /// Maps a task id to the positions of its experiences within `experiences`.
    /// NOTE(review): these positions must be kept in sync whenever an entry is
    /// evicted from the front, since that shifts every remaining position.
    task_index: HashMap<String, Vec<usize>>,
}
80
81impl ExperienceStorage {
82 #[must_use]
84 pub fn new(max_size: usize) -> Self {
85 Self {
86 max_size,
87 experiences: VecDeque::new(),
88 task_index: HashMap::new(),
89 }
90 }
91
92 pub fn add_experience(&mut self, experience: Experience) {
94 let task_id = experience.task_id.clone();
95
96 if self.experiences.len() >= self.max_size {
98 if let Some(removed) = self.experiences.pop_front() {
99 self.remove_from_index(&removed.task_id, 0);
100 }
101 }
102
103 let index = self.experiences.len();
105 self.experiences.push_back(experience);
106
107 self.task_index.entry(task_id).or_default().push(index);
109 }
110
111 #[must_use]
113 pub fn get_task_experiences(&self, task_id: &str) -> Vec<&Experience> {
114 if let Some(indices) = self.task_index.get(task_id) {
115 indices
116 .iter()
117 .filter_map(|&i| self.experiences.get(i))
118 .collect()
119 } else {
120 Vec::new()
121 }
122 }
123
124 #[must_use]
126 pub fn get_all_experiences(&self) -> Vec<&Experience> {
127 self.experiences.iter().collect()
128 }
129
130 #[must_use]
132 pub fn get_similar_experiences(
133 &self,
134 features: &ArrayView2<'_, f64>,
135 k: usize,
136 ) -> Vec<&Experience> {
137 let mut similarities = Vec::new();
138
139 for exp in &self.experiences {
140 let similarity = self.compute_similarity(features, &exp.features.view());
141 similarities.push((similarity, exp));
142 }
143
144 similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
146
147 similarities
149 .into_iter()
150 .take(k)
151 .map(|(_, exp)| exp)
152 .collect()
153 }
154
155 fn compute_similarity(
157 &self,
158 features1: &ArrayView2<'_, f64>,
159 features2: &ArrayView2<'_, f64>,
160 ) -> f64 {
161 if features1.ncols() != features2.ncols() {
162 return 0.0;
163 }
164
165 let mean1 = features1.mean_axis(Axis(0)).unwrap();
167 let mean2 = features2.mean_axis(Axis(0)).unwrap();
168
169 let dot_product = mean1.dot(&mean2);
171 let norm1 = mean1.mapv(|x| x * x).sum().sqrt();
172 let norm2 = mean2.mapv(|x| x * x).sum().sqrt();
173
174 if norm1 == 0.0 || norm2 == 0.0 {
175 0.0
176 } else {
177 dot_product / (norm1 * norm2)
178 }
179 }
180
181 fn remove_from_index(&mut self, task_id: &str, index: usize) {
183 if let Some(indices) = self.task_index.get_mut(task_id) {
184 indices.retain(|&i| i != index);
185 if indices.is_empty() {
186 self.task_index.remove(task_id);
187 }
188 }
189 }
190
191 #[must_use]
193 pub fn len(&self) -> usize {
194 self.experiences.len()
195 }
196
197 #[must_use]
199 pub fn is_empty(&self) -> bool {
200 self.experiences.is_empty()
201 }
202}
203
/// Algorithm used to adapt a parameter map from prior [`Experience`]s.
#[derive(Debug, Clone)]
pub enum AdaptationStrategy {
    /// For `num_steps` iterations, moves every parameter towards the matching
    /// experience parameters, weighted by each experience's `"accuracy"`.
    FineTuning {
        /// Step size applied to the averaged accuracy-weighted difference.
        learning_rate: f64,
        /// Number of update iterations.
        num_steps: usize,
    },
    /// Like `FineTuning`, but only updates parameters whose name contains
    /// `"output"` or `"final"`.
    LastLayerFineTuning {
        /// Step size applied to the averaged accuracy-weighted difference.
        learning_rate: f64,
        /// Number of update iterations.
        num_steps: usize,
    },
    /// Blends each parameter with the accuracy-weighted mean of the experience
    /// parameters: `(1 - w) * current + w * mean`. The task features passed to
    /// `adapt` are currently not used by this strategy.
    FeatureAdaptation { adaptation_weight: f64 },
    /// Averages each parameter with the values from up to `num_similar`
    /// experiences whose `"accuracy"` is at least `similarity_threshold`.
    /// Note that the threshold is compared against accuracy, not feature similarity.
    ParameterAveraging {
        /// Maximum number of qualifying experiences to include.
        num_similar: usize,
        /// Minimum `"accuracy"` an experience needs to be included.
        similarity_threshold: f64,
    },
    /// For `num_inner_steps` iterations, steps each parameter against a
    /// `"loss"`-weighted difference from the experience parameters.
    GradientBased {
        /// Step size for each inner update.
        inner_lr: f64,
        /// Currently unused by the implementation.
        outer_lr: f64,
        /// Number of inner update iterations.
        num_inner_steps: usize,
    },
}
231
232impl AdaptationStrategy {
233 pub fn adapt(
235 &self,
236 current_params: &HashMap<String, f64>,
237 experiences: &[&Experience],
238 task_features: &ArrayView2<'_, f64>,
239 ) -> SklResult<HashMap<String, f64>> {
240 match self {
241 AdaptationStrategy::FineTuning {
242 learning_rate,
243 num_steps,
244 } => self.fine_tune_adaptation(current_params, experiences, *learning_rate, *num_steps),
245 AdaptationStrategy::LastLayerFineTuning {
246 learning_rate,
247 num_steps,
248 } => {
249 self.last_layer_adaptation(current_params, experiences, *learning_rate, *num_steps)
250 }
251 AdaptationStrategy::FeatureAdaptation { adaptation_weight } => self.feature_adaptation(
252 current_params,
253 experiences,
254 task_features,
255 *adaptation_weight,
256 ),
257 AdaptationStrategy::ParameterAveraging {
258 num_similar,
259 similarity_threshold,
260 } => self.parameter_averaging(
261 current_params,
262 experiences,
263 *num_similar,
264 *similarity_threshold,
265 ),
266 AdaptationStrategy::GradientBased {
267 inner_lr,
268 outer_lr,
269 num_inner_steps,
270 } => self.gradient_based_adaptation(
271 current_params,
272 experiences,
273 *inner_lr,
274 *outer_lr,
275 *num_inner_steps,
276 ),
277 }
278 }
279
280 fn fine_tune_adaptation(
282 &self,
283 current_params: &HashMap<String, f64>,
284 experiences: &[&Experience],
285 learning_rate: f64,
286 num_steps: usize,
287 ) -> SklResult<HashMap<String, f64>> {
288 let mut adapted_params = current_params.clone();
289
290 if experiences.is_empty() {
291 return Ok(adapted_params);
292 }
293
294 for _ in 0..num_steps {
296 for (key, value) in &mut adapted_params {
297 let mut gradient = 0.0;
299 let mut count = 0;
300
301 for exp in experiences {
302 if let Some(&exp_param) = exp.parameters.get(key) {
303 if let Some(&performance) = exp.performance.get("accuracy") {
304 gradient += (exp_param - *value) * performance;
306 count += 1;
307 }
308 }
309 }
310
311 if count > 0 {
312 gradient /= f64::from(count);
313 *value += learning_rate * gradient;
314 }
315 }
316 }
317
318 Ok(adapted_params)
319 }
320
321 fn last_layer_adaptation(
323 &self,
324 current_params: &HashMap<String, f64>,
325 experiences: &[&Experience],
326 learning_rate: f64,
327 num_steps: usize,
328 ) -> SklResult<HashMap<String, f64>> {
329 let mut adapted_params = current_params.clone();
330
331 let last_layer_keys: Vec<String> = adapted_params
333 .keys()
334 .filter(|key| key.contains("output") || key.contains("final"))
335 .cloned()
336 .collect();
337
338 for _ in 0..num_steps {
339 for key in &last_layer_keys {
340 if let Some(value) = adapted_params.get_mut(key) {
341 let mut gradient = 0.0;
342 let mut count = 0;
343
344 for exp in experiences {
345 if let Some(&exp_param) = exp.parameters.get(key) {
346 if let Some(&performance) = exp.performance.get("accuracy") {
347 gradient += (exp_param - *value) * performance;
348 count += 1;
349 }
350 }
351 }
352
353 if count > 0 {
354 gradient /= f64::from(count);
355 *value += learning_rate * gradient;
356 }
357 }
358 }
359 }
360
361 Ok(adapted_params)
362 }
363
364 fn feature_adaptation(
366 &self,
367 current_params: &HashMap<String, f64>,
368 experiences: &[&Experience],
369 _task_features: &ArrayView2<'_, f64>,
370 adaptation_weight: f64,
371 ) -> SklResult<HashMap<String, f64>> {
372 let mut adapted_params = current_params.clone();
373
374 if experiences.is_empty() {
375 return Ok(adapted_params);
376 }
377
378 for (key, value) in &mut adapted_params {
380 let mut weighted_sum = 0.0;
381 let mut weight_sum = 0.0;
382
383 for exp in experiences {
384 if let Some(&exp_param) = exp.parameters.get(key) {
385 if let Some(&performance) = exp.performance.get("accuracy") {
386 let weight = performance * adaptation_weight;
387 weighted_sum += exp_param * weight;
388 weight_sum += weight;
389 }
390 }
391 }
392
393 if weight_sum > 0.0 {
394 let adapted_value = weighted_sum / weight_sum;
395 *value = (1.0 - adaptation_weight) * *value + adaptation_weight * adapted_value;
396 }
397 }
398
399 Ok(adapted_params)
400 }
401
402 fn parameter_averaging(
404 &self,
405 current_params: &HashMap<String, f64>,
406 experiences: &[&Experience],
407 num_similar: usize,
408 similarity_threshold: f64,
409 ) -> SklResult<HashMap<String, f64>> {
410 let mut adapted_params = current_params.clone();
411
412 let similar_exps: Vec<&Experience> = experiences
414 .iter()
415 .filter(|exp| {
416 exp.performance
417 .get("accuracy")
418 .is_some_and(|&acc| acc >= similarity_threshold)
419 })
420 .take(num_similar)
421 .copied()
422 .collect();
423
424 if similar_exps.is_empty() {
425 return Ok(adapted_params);
426 }
427
428 for (key, value) in &mut adapted_params {
430 let mut sum = *value;
431 let mut count = 1; for exp in &similar_exps {
434 if let Some(&exp_param) = exp.parameters.get(key) {
435 sum += exp_param;
436 count += 1;
437 }
438 }
439
440 *value = sum / f64::from(count);
441 }
442
443 Ok(adapted_params)
444 }
445
446 fn gradient_based_adaptation(
448 &self,
449 current_params: &HashMap<String, f64>,
450 experiences: &[&Experience],
451 inner_lr: f64,
452 _outer_lr: f64,
453 num_inner_steps: usize,
454 ) -> SklResult<HashMap<String, f64>> {
455 let mut adapted_params = current_params.clone();
456
457 if experiences.is_empty() {
458 return Ok(adapted_params);
459 }
460
461 for _ in 0..num_inner_steps {
463 let mut gradients = HashMap::new();
464
465 for exp in experiences {
467 for (key, &exp_param) in &exp.parameters {
468 if let Some(¤t_param) = adapted_params.get(key) {
469 if let Some(&performance) = exp.performance.get("loss") {
470 let gradient = (current_param - exp_param) * performance;
472 *gradients.entry(key.clone()).or_insert(0.0) += gradient;
473 }
474 }
475 }
476 }
477
478 for (key, gradient) in gradients {
480 if let Some(param) = adapted_params.get_mut(&key) {
481 *param -= inner_lr * gradient / experiences.len() as f64;
482 }
483 }
484 }
485
486 Ok(adapted_params)
487 }
488}
489
/// Pipeline that wraps a base estimator and adapts a set of named
/// meta-parameters using stored experiences from earlier tasks.
///
/// The type parameter `S` tracks the fitting state: `Untrained` before `fit`,
/// and [`MetaLearningPipelineTrained`] afterwards. Once fitted, the live
/// estimator, storage, strategy and parameters are held in `state`. The
/// remaining fields of the fitted pipeline are placeholders set by `fit`.
#[derive(Debug)]
pub struct MetaLearningPipeline<S = Untrained> {
    /// State marker (`Untrained`) or fitted state.
    state: S,
    /// Estimator to be fitted; `fit` takes it, leaving `None` in the fitted pipeline.
    base_estimator: Option<Box<dyn PipelinePredictor>>,
    /// Experiences consulted when adapting `meta_parameters` during `fit`.
    experience_storage: ExperienceStorage,
    /// Strategy used to adapt `meta_parameters`.
    adaptation_strategy: AdaptationStrategy,
    /// Named meta-parameters adapted during `fit`.
    meta_parameters: HashMap<String, f64>,
}
499
/// State held by a fitted [`MetaLearningPipeline`].
#[derive(Debug)]
pub struct MetaLearningPipelineTrained {
    /// The base estimator after `fit`; used for predictions.
    fitted_estimator: Box<dyn PipelinePredictor>,
    /// Experiences available to `adapt_to_task`, which also appends to it.
    experience_storage: ExperienceStorage,
    /// Strategy used by `adapt_to_task`.
    adaptation_strategy: AdaptationStrategy,
    /// Meta-parameters adapted during `fit` and updated by `adapt_to_task`.
    meta_parameters: HashMap<String, f64>,
    /// Number of feature columns in the data passed to `fit`.
    n_features_in: usize,
    /// Feature names seen during `fit`; `fit` currently always sets `None`.
    feature_names_in: Option<Vec<String>>,
}
510
511impl MetaLearningPipeline<Untrained> {
512 #[must_use]
514 pub fn new(base_estimator: Box<dyn PipelinePredictor>) -> Self {
515 Self {
516 state: Untrained,
517 base_estimator: Some(base_estimator),
518 experience_storage: ExperienceStorage::new(1000), adaptation_strategy: AdaptationStrategy::FineTuning {
520 learning_rate: 0.01,
521 num_steps: 10,
522 },
523 meta_parameters: HashMap::new(),
524 }
525 }
526
527 #[must_use]
529 pub fn experience_storage(mut self, storage: ExperienceStorage) -> Self {
530 self.experience_storage = storage;
531 self
532 }
533
534 #[must_use]
536 pub fn adaptation_strategy(mut self, strategy: AdaptationStrategy) -> Self {
537 self.adaptation_strategy = strategy;
538 self
539 }
540
541 #[must_use]
543 pub fn meta_parameters(mut self, params: HashMap<String, f64>) -> Self {
544 self.meta_parameters = params;
545 self
546 }
547
548 pub fn add_experience(&mut self, experience: Experience) {
550 self.experience_storage.add_experience(experience);
551 }
552}
553
/// The untrained pipeline exposes no configurable options through
/// [`Estimator`], so its configuration type is the unit type `()`.
impl Estimator for MetaLearningPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns the (empty) unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
563
564impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
565 for MetaLearningPipeline<Untrained>
566{
567 type Fitted = MetaLearningPipeline<MetaLearningPipelineTrained>;
568
569 fn fit(
570 self,
571 x: &ArrayView2<'_, Float>,
572 y: &Option<&ArrayView1<'_, Float>>,
573 ) -> SklResult<Self::Fitted> {
574 let mut base_estimator = self
575 .base_estimator
576 .ok_or_else(|| SklearsError::InvalidInput("No base estimator provided".to_string()))?;
577
578 if let Some(y_values) = y.as_ref() {
579 let x_f64 = x.mapv(|v| v);
581 let similar_experiences = self
582 .experience_storage
583 .get_similar_experiences(&x_f64.view(), 5);
584
585 let adapted_params = self.adaptation_strategy.adapt(
587 &self.meta_parameters,
588 &similar_experiences,
589 &x_f64.view(),
590 )?;
591
592 base_estimator.fit(x, y_values)?;
594
595 Ok(MetaLearningPipeline {
596 state: MetaLearningPipelineTrained {
597 fitted_estimator: base_estimator,
598 experience_storage: self.experience_storage,
599 adaptation_strategy: self.adaptation_strategy,
600 meta_parameters: adapted_params,
601 n_features_in: x.ncols(),
602 feature_names_in: None,
603 },
604 base_estimator: None,
605 experience_storage: ExperienceStorage::new(0), adaptation_strategy: AdaptationStrategy::FineTuning {
607 learning_rate: 0.01,
608 num_steps: 1,
609 },
610 meta_parameters: HashMap::new(),
611 })
612 } else {
613 Err(SklearsError::InvalidInput(
614 "Target values required for meta-learning".to_string(),
615 ))
616 }
617 }
618}
619
620impl MetaLearningPipeline<MetaLearningPipelineTrained> {
621 pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
623 self.state.fitted_estimator.predict(x)
624 }
625
626 pub fn adapt_to_task(
628 &mut self,
629 task_id: String,
630 x: &ArrayView2<'_, Float>,
631 y: &ArrayView1<'_, Float>,
632 ) -> SklResult<()> {
633 let x_f64 = x.mapv(|v| v);
635 let similar_experiences = self
636 .state
637 .experience_storage
638 .get_similar_experiences(&x_f64.view(), 5);
639
640 let adapted_params = self.state.adaptation_strategy.adapt(
642 &self.state.meta_parameters,
643 &similar_experiences,
644 &x_f64.view(),
645 )?;
646
647 self.state.meta_parameters = adapted_params;
649
650 let experience = Experience::new(task_id, x_f64, y.mapv(|v| v))
652 .with_parameters(self.state.meta_parameters.clone());
653
654 self.state.experience_storage.add_experience(experience);
655
656 Ok(())
657 }
658
659 #[must_use]
661 pub fn experience_storage(&self) -> &ExperienceStorage {
662 &self.state.experience_storage
663 }
664
665 #[must_use]
667 pub fn meta_parameters(&self) -> &HashMap<String, f64> {
668 &self.state.meta_parameters
669 }
670}
671
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_experience_storage() {
        let mut storage = ExperienceStorage::new(3);

        let exp1 = Experience::new(
            "task1".to_string(),
            array![[1.0, 2.0], [3.0, 4.0]],
            array![1.0, 0.0],
        );

        storage.add_experience(exp1);
        assert_eq!(storage.len(), 1);

        let task_exps = storage.get_task_experiences("task1");
        assert_eq!(task_exps.len(), 1);
    }

    #[test]
    fn test_unknown_task_has_no_experiences() {
        let storage = ExperienceStorage::new(3);
        assert!(storage.is_empty());
        assert!(storage.get_task_experiences("missing").is_empty());
    }

    #[test]
    fn test_meta_learning_pipeline() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let mut pipeline = MetaLearningPipeline::new(base_estimator);

        let experience = Experience::new("task1".to_string(), x.clone(), y.clone());
        pipeline.add_experience(experience);

        let fitted_pipeline = pipeline.fit(&x.view(), &Some(&y.view())).unwrap();
        let predictions = fitted_pipeline.predict(&x.view()).unwrap();

        assert_eq!(predictions.len(), x.nrows());
    }

    #[test]
    fn test_fit_requires_targets() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let pipeline = MetaLearningPipeline::new(Box::new(MockPredictor::new()));

        let no_targets: Option<&ArrayView1<'_, Float>> = None;
        assert!(pipeline.fit(&x.view(), &no_targets).is_err());
    }

    #[test]
    fn test_adaptation_strategies() {
        let mut params = HashMap::new();
        params.insert("param1".to_string(), 1.0);
        params.insert("param2".to_string(), 2.0);

        let experience = Experience::new("task1".to_string(), array![[1.0, 2.0]], array![1.0])
            .with_parameters(params.clone())
            .with_performance([("accuracy".to_string(), 0.8)].iter().cloned().collect());

        let experiences = vec![&experience];
        let features = array![[1.0, 2.0]];

        let strategy = AdaptationStrategy::FineTuning {
            learning_rate: 0.1,
            num_steps: 5,
        };

        let adapted = strategy
            .adapt(&params, &experiences, &features.view())
            .unwrap();
        assert_eq!(adapted.len(), 2);
        // The experience's parameters equal the current ones, so the
        // accuracy-weighted pull is zero and nothing moves.
        assert_eq!(adapted["param1"], 1.0);
        assert_eq!(adapted["param2"], 2.0);
    }
}