sklears_compose/meta_learning.rs

//! Meta-learning pipeline components
//!
//! This module provides meta-learning building blocks: experience storage,
//! adaptation strategies, and a meta-learning pipeline wrapper.
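//!
//! A minimal usage sketch (marked `ignore` because it assumes a concrete
//! `PipelinePredictor` implementation plus placeholder `x`/`y` arrays that are
//! not defined here):
//!
//! ```ignore
//! let mut pipeline = MetaLearningPipeline::new(base_estimator)
//!     .adaptation_strategy(AdaptationStrategy::FineTuning {
//!         learning_rate: 0.01,
//!         num_steps: 10,
//!     });
//!
//! pipeline.add_experience(Experience::new("task1".to_string(), x_old, y_old));
//!
//! let fitted = pipeline.fit(&x.view(), &Some(&y.view()))?;
//! let predictions = fitted.predict(&x.view())?;
//! ```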

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::Result as SklResult,
    prelude::{Predict, SklearsError},
    traits::{Estimator, Fit, Untrained},
    types::Float,
};
use std::collections::{HashMap, VecDeque};

use crate::{PipelinePredictor, PipelineStep};

/// Experience entry for meta-learning
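///
/// A small builder sketch; the metric and parameter names used below are
/// illustrative assumptions, not a fixed schema.
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
///
/// let experience = Experience::new(
///     "task1".to_string(),
///     array![[1.0, 2.0], [3.0, 4.0]],
///     array![1.0, 0.0],
/// )
/// .with_performance(HashMap::from([("accuracy".to_string(), 0.9)]))
/// .with_parameters(HashMap::from([("learning_rate".to_string(), 0.01)]));
/// ```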
#[derive(Debug, Clone)]
pub struct Experience {
    /// Task identifier
    pub task_id: String,
    /// Input features for the task
    pub features: Array2<f64>,
    /// Target values for the task
    pub targets: Array1<f64>,
    /// Task metadata (e.g., task type, domain)
    pub metadata: HashMap<String, String>,
    /// Performance metrics achieved on this task
    pub performance: HashMap<String, f64>,
    /// Model parameters used for this task
    pub parameters: HashMap<String, f64>,
}

impl Experience {
    /// Create a new experience entry
    #[must_use]
    pub fn new(task_id: String, features: Array2<f64>, targets: Array1<f64>) -> Self {
        Self {
            task_id,
            features,
            targets,
            metadata: HashMap::new(),
            performance: HashMap::new(),
            parameters: HashMap::new(),
        }
    }

    /// Add metadata to the experience
    #[must_use]
    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
        self.metadata = metadata;
        self
    }

    /// Add performance metrics to the experience
    #[must_use]
    pub fn with_performance(mut self, performance: HashMap<String, f64>) -> Self {
        self.performance = performance;
        self
    }

    /// Add model parameters to the experience
    #[must_use]
    pub fn with_parameters(mut self, parameters: HashMap<String, f64>) -> Self {
        self.parameters = parameters;
        self
    }
}

/// Experience storage for meta-learning
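///
/// A bounded FIFO buffer of [`Experience`] values with a task-id index. A
/// minimal sketch of the intended flow (capacity and task ids are made up for
/// illustration):
///
/// ```ignore
/// let mut storage = ExperienceStorage::new(100);
/// storage.add_experience(Experience::new(
///     "task1".to_string(),
///     array![[1.0, 2.0]],
///     array![1.0],
/// ));
///
/// // Look up by task id, or by feature similarity to a new task.
/// let by_task = storage.get_task_experiences("task1");
/// let query = array![[1.5, 2.5]];
/// let nearest = storage.get_similar_experiences(&query.view(), 3);
/// ```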
#[derive(Debug, Clone)]
pub struct ExperienceStorage {
    /// Maximum number of experiences to store
    max_size: usize,
    /// Stored experiences
    experiences: VecDeque<Experience>,
    /// Index by task ID for fast lookup
    task_index: HashMap<String, Vec<usize>>,
}

impl ExperienceStorage {
    /// Create a new experience storage
    #[must_use]
    pub fn new(max_size: usize) -> Self {
        Self {
            max_size,
            experiences: VecDeque::new(),
            task_index: HashMap::new(),
        }
    }

    /// Add an experience to the storage
    pub fn add_experience(&mut self, experience: Experience) {
        let task_id = experience.task_id.clone();

        // Remove the oldest experience if at capacity
        if self.experiences.len() >= self.max_size {
            if let Some(removed) = self.experiences.pop_front() {
                self.remove_from_index(&removed.task_id, 0);
                // Popping the front shifts every remaining experience left by
                // one, so shift the stored indices to match.
                for indices in self.task_index.values_mut() {
                    for index in indices.iter_mut() {
                        *index -= 1;
                    }
                }
            }
        }

        // Add the new experience and record its position in the index
        let index = self.experiences.len();
        self.experiences.push_back(experience);
        self.task_index.entry(task_id).or_default().push(index);
    }

    /// Get experiences for a specific task
    #[must_use]
    pub fn get_task_experiences(&self, task_id: &str) -> Vec<&Experience> {
        if let Some(indices) = self.task_index.get(task_id) {
            indices
                .iter()
                .filter_map(|&i| self.experiences.get(i))
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Get all experiences
    #[must_use]
    pub fn get_all_experiences(&self) -> Vec<&Experience> {
        self.experiences.iter().collect()
    }

    /// Get the most similar experiences based on feature similarity
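    ///
    /// Experiences are ranked by the cosine similarity of their mean feature
    /// vectors (see `compute_similarity`) and the top `k` are returned. A small
    /// illustrative sketch, assuming a populated `storage`:
    ///
    /// ```ignore
    /// let query = array![[0.5, 1.0], [1.5, 2.0]];
    /// for exp in storage.get_similar_experiences(&query.view(), 3) {
    ///     println!("similar task: {}", exp.task_id);
    /// }
    /// ```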
    #[must_use]
    pub fn get_similar_experiences(
        &self,
        features: &ArrayView2<'_, f64>,
        k: usize,
    ) -> Vec<&Experience> {
        let mut similarities = Vec::new();

        for exp in &self.experiences {
            let similarity = self.compute_similarity(features, &exp.features.view());
            similarities.push((similarity, exp));
        }

        // Sort by similarity (descending)
        similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        // Return top k
        similarities
            .into_iter()
            .take(k)
            .map(|(_, exp)| exp)
            .collect()
    }

    /// Compute similarity between two feature sets (cosine similarity)
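    ///
    /// Each matrix is collapsed to its column-wise mean, then the similarity is
    /// `dot(m1, m2) / (||m1|| * ||m2||)`. Mismatched column counts, empty
    /// inputs, or a zero-norm mean all yield `0.0`.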
    fn compute_similarity(
        &self,
        features1: &ArrayView2<'_, f64>,
        features2: &ArrayView2<'_, f64>,
    ) -> f64 {
        if features1.ncols() != features2.ncols()
            || features1.nrows() == 0
            || features2.nrows() == 0
        {
            return 0.0;
        }

        // Summarize each matrix by its column-wise mean; the empty check above
        // guarantees mean_axis returns Some here.
        let mean1 = features1.mean_axis(Axis(0)).unwrap();
        let mean2 = features2.mean_axis(Axis(0)).unwrap();

        // Cosine similarity of the mean vectors
        let dot_product = mean1.dot(&mean2);
        let norm1 = mean1.mapv(|x| x * x).sum().sqrt();
        let norm2 = mean2.mapv(|x| x * x).sum().sqrt();

        if norm1 == 0.0 || norm2 == 0.0 {
            0.0
        } else {
            dot_product / (norm1 * norm2)
        }
    }

    /// Remove an experience from the index
    fn remove_from_index(&mut self, task_id: &str, index: usize) {
        if let Some(indices) = self.task_index.get_mut(task_id) {
            indices.retain(|&i| i != index);
            if indices.is_empty() {
                self.task_index.remove(task_id);
            }
        }
    }

    /// Get the number of stored experiences
    #[must_use]
    pub fn len(&self) -> usize {
        self.experiences.len()
    }

    /// Check if storage is empty
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.experiences.is_empty()
    }
}

/// Adaptation strategy for meta-learning
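///
/// Each variant bundles the hyperparameters for one way of turning stored
/// [`Experience`]s into adjusted parameters. A minimal sketch of picking a
/// strategy and applying it (assumes a populated `storage` and a
/// `task_features` array; the parameter name is illustrative):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let strategy = AdaptationStrategy::FeatureAdaptation {
///     adaptation_weight: 0.3,
/// };
///
/// let current = HashMap::from([("weight".to_string(), 0.5)]);
/// let experiences = storage.get_all_experiences();
/// let adapted = strategy.adapt(&current, &experiences, &task_features.view())?;
/// ```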
#[derive(Debug, Clone)]
pub enum AdaptationStrategy {
    /// Fine-tune all parameters
    FineTuning {
        learning_rate: f64,
        num_steps: usize,
    },
    /// Fine-tune only the last layer
    LastLayerFineTuning {
        learning_rate: f64,
        num_steps: usize,
    },
    /// Feature-based adaptation
    FeatureAdaptation { adaptation_weight: f64 },
    /// Parameter averaging from similar tasks
    ParameterAveraging {
        num_similar: usize,
        similarity_threshold: f64,
    },
    /// Gradient-based meta-learning (MAML-style)
    GradientBased {
        inner_lr: f64,
        outer_lr: f64,
        num_inner_steps: usize,
    },
}

impl AdaptationStrategy {
    /// Apply the adaptation strategy
    pub fn adapt(
        &self,
        current_params: &HashMap<String, f64>,
        experiences: &[&Experience],
        task_features: &ArrayView2<'_, f64>,
    ) -> SklResult<HashMap<String, f64>> {
        match self {
            AdaptationStrategy::FineTuning {
                learning_rate,
                num_steps,
            } => self.fine_tune_adaptation(current_params, experiences, *learning_rate, *num_steps),
            AdaptationStrategy::LastLayerFineTuning {
                learning_rate,
                num_steps,
            } => {
                self.last_layer_adaptation(current_params, experiences, *learning_rate, *num_steps)
            }
            AdaptationStrategy::FeatureAdaptation { adaptation_weight } => self.feature_adaptation(
                current_params,
                experiences,
                task_features,
                *adaptation_weight,
            ),
            AdaptationStrategy::ParameterAveraging {
                num_similar,
                similarity_threshold,
            } => self.parameter_averaging(
                current_params,
                experiences,
                *num_similar,
                *similarity_threshold,
            ),
            AdaptationStrategy::GradientBased {
                inner_lr,
                outer_lr,
                num_inner_steps,
            } => self.gradient_based_adaptation(
                current_params,
                experiences,
                *inner_lr,
                *outer_lr,
                *num_inner_steps,
            ),
        }
    }

    /// Fine-tuning adaptation
    fn fine_tune_adaptation(
        &self,
        current_params: &HashMap<String, f64>,
        experiences: &[&Experience],
        learning_rate: f64,
        num_steps: usize,
    ) -> SklResult<HashMap<String, f64>> {
        let mut adapted_params = current_params.clone();

        if experiences.is_empty() {
            return Ok(adapted_params);
        }

        // Simple gradient descent simulation
        for _ in 0..num_steps {
            for (key, value) in &mut adapted_params {
                // Compute pseudo-gradient based on experience performance
                let mut gradient = 0.0;
                let mut count = 0;

                for exp in experiences {
                    if let Some(&exp_param) = exp.parameters.get(key) {
                        if let Some(&performance) = exp.performance.get("accuracy") {
                            // Simple gradient approximation
                            gradient += (exp_param - *value) * performance;
                            count += 1;
                        }
                    }
                }

                if count > 0 {
                    gradient /= f64::from(count);
                    *value += learning_rate * gradient;
                }
            }
        }

        Ok(adapted_params)
    }

    /// Last layer fine-tuning adaptation
    fn last_layer_adaptation(
        &self,
        current_params: &HashMap<String, f64>,
        experiences: &[&Experience],
        learning_rate: f64,
        num_steps: usize,
    ) -> SklResult<HashMap<String, f64>> {
        let mut adapted_params = current_params.clone();

        // Only adapt parameters that contain "output" or "final" in their name
        let last_layer_keys: Vec<String> = adapted_params
            .keys()
            .filter(|key| key.contains("output") || key.contains("final"))
            .cloned()
            .collect();

        for _ in 0..num_steps {
            for key in &last_layer_keys {
                if let Some(value) = adapted_params.get_mut(key) {
                    let mut gradient = 0.0;
                    let mut count = 0;

                    for exp in experiences {
                        if let Some(&exp_param) = exp.parameters.get(key) {
                            if let Some(&performance) = exp.performance.get("accuracy") {
                                gradient += (exp_param - *value) * performance;
                                count += 1;
                            }
                        }
                    }

                    if count > 0 {
                        gradient /= f64::from(count);
                        *value += learning_rate * gradient;
                    }
                }
            }
        }

        Ok(adapted_params)
    }

    /// Feature-based adaptation
    fn feature_adaptation(
        &self,
        current_params: &HashMap<String, f64>,
        experiences: &[&Experience],
        _task_features: &ArrayView2<'_, f64>,
        adaptation_weight: f64,
    ) -> SklResult<HashMap<String, f64>> {
        let mut adapted_params = current_params.clone();

        if experiences.is_empty() {
            return Ok(adapted_params);
        }

        // Blend each parameter toward a performance-weighted average of past experience parameters
        for (key, value) in &mut adapted_params {
            let mut weighted_sum = 0.0;
            let mut weight_sum = 0.0;

            for exp in experiences {
                if let Some(&exp_param) = exp.parameters.get(key) {
                    if let Some(&performance) = exp.performance.get("accuracy") {
                        let weight = performance * adaptation_weight;
                        weighted_sum += exp_param * weight;
                        weight_sum += weight;
                    }
                }
            }

            if weight_sum > 0.0 {
                let adapted_value = weighted_sum / weight_sum;
                *value = (1.0 - adaptation_weight) * *value + adaptation_weight * adapted_value;
            }
        }

        Ok(adapted_params)
    }

    /// Parameter averaging adaptation
    fn parameter_averaging(
        &self,
        current_params: &HashMap<String, f64>,
        experiences: &[&Experience],
        num_similar: usize,
        similarity_threshold: f64,
    ) -> SklResult<HashMap<String, f64>> {
        let mut adapted_params = current_params.clone();

        // Keep up to `num_similar` experiences whose accuracy clears the threshold (performance acts as the similarity proxy here)
        let similar_exps: Vec<&Experience> = experiences
            .iter()
            .filter(|exp| {
                exp.performance
                    .get("accuracy")
                    .is_some_and(|&acc| acc >= similarity_threshold)
            })
            .take(num_similar)
            .copied()
            .collect();

        if similar_exps.is_empty() {
            return Ok(adapted_params);
        }

        // Average parameters from similar experiences
        for (key, value) in &mut adapted_params {
            let mut sum = *value;
            let mut count = 1; // Include current parameter

            for exp in &similar_exps {
                if let Some(&exp_param) = exp.parameters.get(key) {
                    sum += exp_param;
                    count += 1;
                }
            }

            *value = sum / f64::from(count);
        }

        Ok(adapted_params)
    }

    /// Gradient-based adaptation (MAML-style)
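    ///
    /// This is a simplified, first-order sketch of MAML's inner loop: parameters
    /// are nudged toward stored experience parameters using loss-weighted
    /// differences. No outer-loop update is performed here, so `outer_lr` is
    /// currently unused.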
    fn gradient_based_adaptation(
        &self,
        current_params: &HashMap<String, f64>,
        experiences: &[&Experience],
        inner_lr: f64,
        _outer_lr: f64,
        num_inner_steps: usize,
    ) -> SklResult<HashMap<String, f64>> {
        let mut adapted_params = current_params.clone();

        if experiences.is_empty() {
            return Ok(adapted_params);
        }

        // Simulate inner loop adaptation
        for _ in 0..num_inner_steps {
            let mut gradients = HashMap::new();

            // Compute gradients based on experiences
            for exp in experiences {
                for (key, &exp_param) in &exp.parameters {
                    if let Some(&current_param) = adapted_params.get(key) {
                        if let Some(&performance) = exp.performance.get("loss") {
                            // Compute gradient as loss-weighted parameter difference
                            let gradient = (current_param - exp_param) * performance;
                            *gradients.entry(key.clone()).or_insert(0.0) += gradient;
                        }
                    }
                }
            }

            // Apply gradients
            for (key, gradient) in gradients {
                if let Some(param) = adapted_params.get_mut(&key) {
                    *param -= inner_lr * gradient / experiences.len() as f64;
                }
            }
        }

        Ok(adapted_params)
    }
}

/// Meta-learning pipeline component
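///
/// Wraps a base [`PipelinePredictor`] and adapts its meta-parameters from stored
/// [`Experience`]s before fitting. A minimal end-to-end sketch (the estimator,
/// data, and task id are placeholders):
///
/// ```ignore
/// let mut pipeline = MetaLearningPipeline::new(base_estimator)
///     .experience_storage(ExperienceStorage::new(500))
///     .adaptation_strategy(AdaptationStrategy::ParameterAveraging {
///         num_similar: 5,
///         similarity_threshold: 0.7,
///     });
///
/// pipeline.add_experience(past_experience);
///
/// let fitted = pipeline.fit(&x.view(), &Some(&y.view()))?;
/// let predictions = fitted.predict(&x_new.view())?;
/// ```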
#[derive(Debug)]
pub struct MetaLearningPipeline<S = Untrained> {
    state: S,
    base_estimator: Option<Box<dyn PipelinePredictor>>,
    experience_storage: ExperienceStorage,
    adaptation_strategy: AdaptationStrategy,
    meta_parameters: HashMap<String, f64>,
}

/// Trained state for `MetaLearningPipeline`
#[derive(Debug)]
pub struct MetaLearningPipelineTrained {
    fitted_estimator: Box<dyn PipelinePredictor>,
    experience_storage: ExperienceStorage,
    adaptation_strategy: AdaptationStrategy,
    meta_parameters: HashMap<String, f64>,
    n_features_in: usize,
    feature_names_in: Option<Vec<String>>,
}

impl MetaLearningPipeline<Untrained> {
    /// Create a new meta-learning pipeline
    #[must_use]
    pub fn new(base_estimator: Box<dyn PipelinePredictor>) -> Self {
        Self {
            state: Untrained,
            base_estimator: Some(base_estimator),
            experience_storage: ExperienceStorage::new(1000), // Default max size
            adaptation_strategy: AdaptationStrategy::FineTuning {
                learning_rate: 0.01,
                num_steps: 10,
            },
            meta_parameters: HashMap::new(),
        }
    }

    /// Set the experience storage
    #[must_use]
    pub fn experience_storage(mut self, storage: ExperienceStorage) -> Self {
        self.experience_storage = storage;
        self
    }

    /// Set the adaptation strategy
    #[must_use]
    pub fn adaptation_strategy(mut self, strategy: AdaptationStrategy) -> Self {
        self.adaptation_strategy = strategy;
        self
    }

    /// Set meta-parameters
    #[must_use]
    pub fn meta_parameters(mut self, params: HashMap<String, f64>) -> Self {
        self.meta_parameters = params;
        self
    }

    /// Add an experience to the storage
    pub fn add_experience(&mut self, experience: Experience) {
        self.experience_storage.add_experience(experience);
    }
}

impl Estimator for MetaLearningPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
    for MetaLearningPipeline<Untrained>
{
    type Fitted = MetaLearningPipeline<MetaLearningPipelineTrained>;

    fn fit(
        self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Self::Fitted> {
        let mut base_estimator = self
            .base_estimator
            .ok_or_else(|| SklearsError::InvalidInput("No base estimator provided".to_string()))?;

        if let Some(y_values) = y.as_ref() {
            // Get similar experiences for adaptation
            let x_f64 = x.mapv(|v| v);
            let similar_experiences = self
                .experience_storage
                .get_similar_experiences(&x_f64.view(), 5);

            // Adapt meta-parameters based on experiences
            let adapted_params = self.adaptation_strategy.adapt(
                &self.meta_parameters,
                &similar_experiences,
                &x_f64.view(),
            )?;

            // Fit the base estimator
            base_estimator.fit(x, y_values)?;

            Ok(MetaLearningPipeline {
                state: MetaLearningPipelineTrained {
                    fitted_estimator: base_estimator,
                    experience_storage: self.experience_storage,
                    adaptation_strategy: self.adaptation_strategy,
                    meta_parameters: adapted_params,
                    n_features_in: x.ncols(),
                    feature_names_in: None,
                },
                base_estimator: None,
                experience_storage: ExperienceStorage::new(0), // Placeholder
                adaptation_strategy: AdaptationStrategy::FineTuning {
                    learning_rate: 0.01,
                    num_steps: 1,
                },
                meta_parameters: HashMap::new(),
            })
        } else {
            Err(SklearsError::InvalidInput(
                "Target values required for meta-learning".to_string(),
            ))
        }
    }
}

impl MetaLearningPipeline<MetaLearningPipelineTrained> {
    /// Predict using the fitted meta-learning pipeline
    pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
        self.state.fitted_estimator.predict(x)
    }

    /// Adapt to a new task with limited data
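    ///
    /// Retrieves similar stored experiences, adapts the meta-parameters, and
    /// records the new task as a fresh experience. A small sketch with
    /// placeholder data:
    ///
    /// ```ignore
    /// let x_new = array![[0.2, 0.4], [0.6, 0.8]];
    /// let y_new = array![0.0, 1.0];
    /// fitted.adapt_to_task("task2".to_string(), &x_new.view(), &y_new.view())?;
    /// ```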
    pub fn adapt_to_task(
        &mut self,
        task_id: String,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
    ) -> SklResult<()> {
        // Get similar experiences
        let x_f64 = x.mapv(|v| v);
        let similar_experiences = self
            .state
            .experience_storage
            .get_similar_experiences(&x_f64.view(), 5);

        // Adapt parameters
        let adapted_params = self.state.adaptation_strategy.adapt(
            &self.state.meta_parameters,
            &similar_experiences,
            &x_f64.view(),
        )?;

        // Update meta-parameters
        self.state.meta_parameters = adapted_params;

        // Create and store new experience
        let experience = Experience::new(task_id, x_f64, y.mapv(|v| v))
            .with_parameters(self.state.meta_parameters.clone());

        self.state.experience_storage.add_experience(experience);

        Ok(())
    }

    /// Get the experience storage
    #[must_use]
    pub fn experience_storage(&self) -> &ExperienceStorage {
        &self.state.experience_storage
    }

    /// Get the current meta-parameters
    #[must_use]
    pub fn meta_parameters(&self) -> &HashMap<String, f64> {
        &self.state.meta_parameters
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_experience_storage() {
        let mut storage = ExperienceStorage::new(3);

        let exp1 = Experience::new(
            "task1".to_string(),
            array![[1.0, 2.0], [3.0, 4.0]],
            array![1.0, 0.0],
        );

        storage.add_experience(exp1);
        assert_eq!(storage.len(), 1);

        let task_exps = storage.get_task_experiences("task1");
        assert_eq!(task_exps.len(), 1);
    }

    #[test]
    fn test_meta_learning_pipeline() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let mut pipeline = MetaLearningPipeline::new(base_estimator);

        // Add some experience
        let experience = Experience::new(
            "task1".to_string(),
            x.mapv(|v| v as f64),
            y.mapv(|v| v as f64),
        );
        pipeline.add_experience(experience);

        let fitted_pipeline = pipeline.fit(&x.view(), &Some(&y.view())).unwrap();
        let predictions = fitted_pipeline.predict(&x.view()).unwrap();

        assert_eq!(predictions.len(), x.nrows());
    }

    #[test]
    fn test_adaptation_strategies() {
        let mut params = HashMap::new();
        params.insert("param1".to_string(), 1.0);
        params.insert("param2".to_string(), 2.0);

        let experience = Experience::new("task1".to_string(), array![[1.0, 2.0]], array![1.0])
            .with_parameters(params.clone())
            .with_performance([("accuracy".to_string(), 0.8)].iter().cloned().collect());

        let experiences = vec![&experience];
        let features = array![[1.0, 2.0]];

        let strategy = AdaptationStrategy::FineTuning {
            learning_rate: 0.1,
            num_steps: 5,
        };

        let adapted = strategy
            .adapt(&params, &experiences, &features.view())
            .unwrap();
        assert_eq!(adapted.len(), 2);
    }
}