sklears_ensemble/
multi_task.rs

1//! Multi-Task Ensemble Methods
2//!
3//! This module provides ensemble methods for multi-task learning, where multiple
4//! related learning tasks are solved jointly to improve generalization performance
5//! by leveraging information shared across tasks.
6
7use crate::bagging::BaggingClassifier;
8use crate::gradient_boosting::{
9    GradientBoostingConfig, GradientBoostingRegressor, TrainedGradientBoostingRegressor,
10};
11use scirs2_core::ndarray::{Array1, Array2};
12use sklears_core::{
13    error::Result as SklResult,
14    prelude::{Predict, SklearsError},
15    traits::{Estimator, Fit, Trained, Untrained},
16};
17use std::collections::HashMap;
18
/// A trait that combines the `Estimator` and `Predict` traits.
///
/// Serves as a trait alias: any base learner usable in a multi-task ensemble
/// must both be configurable (`Estimator`) and produce predictions
/// (`Predict`). The blanket impl below means no type ever implements this
/// trait by hand.
pub trait MultiTaskEstimator<C, E, F, X, Y>:
    Estimator<Config = C, Error = E, Float = F> + Predict<X, Y>
{
}

// Blanket implementation: every `Estimator + Predict` automatically
// satisfies `MultiTaskEstimator`.
impl<T, C, E, F, X, Y> MultiTaskEstimator<C, E, F, X, Y> for T where
    T: Estimator<Config = C, Error = E, Float = F> + Predict<X, Y>
{
}
29
/// Configuration for multi-task ensemble learning.
///
/// Controls how many base estimators are trained per task, how information
/// is shared between tasks, and how tasks are weighted and validated.
/// Construct via [`MultiTaskEnsembleConfig::builder`] or rely on `Default`.
#[derive(Debug, Clone)]
pub struct MultiTaskEnsembleConfig {
    /// Number of base estimators per task
    pub n_estimators_per_task: usize,
    /// Task sharing strategy
    pub sharing_strategy: TaskSharingStrategy,
    /// Task similarity metric for adaptive sharing
    pub similarity_metric: TaskSimilarityMetric,
    /// Minimum task similarity threshold for sharing
    pub min_similarity_threshold: f64,
    /// Task weighting strategy
    pub task_weighting: TaskWeightingStrategy,
    /// Whether to use task-specific feature selection
    pub use_task_specific_features: bool,
    /// Number of shared features across tasks
    // NOTE(review): not consumed by any code visible in this module — confirm intended use.
    pub n_shared_features: Option<usize>,
    /// Regularization strength for task sharing
    pub sharing_regularization: f64,
    /// Maximum depth for task hierarchy
    pub max_task_depth: usize,
    /// Cross-task validation strategy
    pub cross_task_validation: CrossTaskValidation,
}
54
55impl Default for MultiTaskEnsembleConfig {
56    fn default() -> Self {
57        Self {
58            n_estimators_per_task: 10,
59            sharing_strategy: TaskSharingStrategy::SharedRepresentation,
60            similarity_metric: TaskSimilarityMetric::CorrelationBased,
61            min_similarity_threshold: 0.3,
62            task_weighting: TaskWeightingStrategy::Uniform,
63            use_task_specific_features: true,
64            n_shared_features: None,
65            sharing_regularization: 0.1,
66            max_task_depth: 3,
67            cross_task_validation: CrossTaskValidation::LeaveOneTaskOut,
68        }
69    }
70}
71
/// Strategies for sharing information between tasks.
///
/// Note: `MultiLevelSharing` and `TransferLearning` are not yet implemented
/// by the regressor's training dispatch and currently fall back to
/// independent per-task training.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskSharingStrategy {
    /// No sharing between tasks
    Independent,
    /// Shared representation learning
    SharedRepresentation,
    /// Parameter sharing between tasks
    ParameterSharing,
    /// Hierarchical task relationships
    HierarchicalSharing,
    /// Adaptive sharing based on task similarity
    AdaptiveSharing,
    /// Multi-level sharing with different granularities
    MultiLevelSharing,
    /// Transfer learning between tasks
    TransferLearning,
}
90
/// Metrics for measuring task similarity.
///
/// Note: only `CorrelationBased` and `DistributionSimilarity` are currently
/// implemented; the other variants fall back to correlation-based similarity.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskSimilarityMetric {
    /// Correlation-based similarity
    CorrelationBased,
    /// Feature importance similarity
    FeatureImportanceSimilarity,
    /// Model prediction similarity
    PredictionSimilarity,
    /// Data distribution similarity
    DistributionSimilarity,
    /// Gradient similarity
    GradientSimilarity,
    /// Task performance correlation
    PerformanceCorrelation,
}
107
/// Strategies for weighting different tasks.
///
/// Note: only `Uniform` and `SampleSizeBased` are currently implemented;
/// the other variants fall back to uniform weighting.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskWeightingStrategy {
    /// Equal weight for all tasks
    Uniform,
    /// Weight based on task difficulty
    DifficultyBased,
    /// Weight based on task sample size
    SampleSizeBased,
    /// Weight based on task performance
    PerformanceBased,
    /// Adaptive weighting during training
    AdaptiveWeighting,
    /// Weight based on task importance
    ImportanceBased,
}
124
/// Cross-task validation strategies.
// NOTE(review): no validation routine is visible in this module yet; these
// variants are configuration-only for now — confirm against the rest of the file.
#[derive(Debug, Clone, PartialEq)]
pub enum CrossTaskValidation {
    /// Leave one task out for validation
    LeaveOneTaskOut,
    /// Cross-validation within each task
    WithinTaskCV,
    /// Hierarchical cross-validation
    HierarchicalCV,
    /// Time-based validation for temporal tasks
    TemporalCV,
    /// Stratified validation across tasks
    StratifiedCV,
}
139
/// Multi-task ensemble classifier.
///
/// Uses the typestate pattern: `State` defaults to `Untrained`; the
/// `Option` fields are only populated on a `Trained` instance.
pub struct MultiTaskEnsembleClassifier<State = Untrained> {
    config: MultiTaskEnsembleConfig,
    // Zero-sized marker carrying the train-state type parameter.
    state: std::marker::PhantomData<State>,
    // Fitted attributes - only populated after training
    /// Per-task bagging ensembles, keyed by task id.
    task_models: Option<HashMap<String, Vec<BaggingClassifier<Trained>>>>,
    /// Models trained on data pooled across tasks.
    shared_models: Option<Vec<BaggingClassifier<Trained>>>,
    /// Pairwise task similarities keyed by (task_a, task_b) in both orders.
    task_similarities: Option<HashMap<(String, String), f64>>,
    /// Per-task weights keyed by task id.
    task_weights: Option<HashMap<String, f64>>,
    /// Task-specific/shared feature masks and importances.
    feature_selector: Option<MultiTaskFeatureSelector>,
    /// Parent/child task relationships for hierarchical sharing.
    task_hierarchy: Option<TaskHierarchy>,
}
152
/// Multi-task ensemble regressor.
///
/// Uses the typestate pattern: `State` defaults to `Untrained` and becomes
/// `Trained` via `fit_tasks`; the `Option` fields are only populated on the
/// trained instance.
pub struct MultiTaskEnsembleRegressor<State = Untrained> {
    config: MultiTaskEnsembleConfig,
    // Zero-sized marker carrying the train-state type parameter.
    state: std::marker::PhantomData<State>,
    // Fitted attributes - only populated after training
    /// Per-task gradient-boosting ensembles, keyed by task id.
    task_models: Option<HashMap<String, Vec<TrainedGradientBoostingRegressor>>>,
    /// Models trained on data pooled across tasks.
    shared_models: Option<Vec<TrainedGradientBoostingRegressor>>,
    /// Pairwise task similarities keyed by (task_a, task_b) in both orders.
    task_similarities: Option<HashMap<(String, String), f64>>,
    /// Per-task weights keyed by task id.
    task_weights: Option<HashMap<String, f64>>,
    /// Task-specific/shared feature masks and importances.
    feature_selector: Option<MultiTaskFeatureSelector>,
    /// Parent/child task relationships for hierarchical sharing.
    task_hierarchy: Option<TaskHierarchy>,
}
165
/// Task hierarchy for hierarchical sharing.
#[derive(Debug, Clone)]
pub struct TaskHierarchy {
    /// Parent-child relationships between tasks (task id -> child task ids)
    pub hierarchy: HashMap<String, Vec<String>>,
    /// Task depth in the hierarchy (task id -> depth, root = 0)
    pub task_depths: HashMap<String, usize>,
    /// Sharing weights based on hierarchy, keyed by (parent, child) task ids
    pub hierarchy_weights: HashMap<(String, String), f64>,
}
176
/// Multi-task feature selector.
///
/// All vectors are indexed by feature position and have one entry per
/// feature column of the task data.
#[derive(Debug, Clone)]
pub struct MultiTaskFeatureSelector {
    /// Task-specific feature masks (true = feature kept for that task)
    pub task_feature_masks: HashMap<String, Vec<bool>>,
    /// Shared feature mask (true = feature considered important across tasks)
    pub shared_feature_mask: Vec<bool>,
    /// Feature importance scores per task
    pub task_feature_importances: HashMap<String, Vec<f64>>,
    /// Global feature importance scores (mean of per-task importances)
    pub global_feature_importances: Vec<f64>,
}
189
/// Task-specific training data.
///
/// `features` is a (n_samples, n_features) matrix and `labels` holds one
/// target value per sample row.
#[derive(Debug, Clone)]
pub struct TaskData {
    /// Task identifier
    pub task_id: String,
    /// Features for this task, shape (n_samples, n_features)
    pub features: Array2<f64>,
    /// Labels for this task, one per sample
    pub labels: Vec<f64>,
    /// Sample weights (optional)
    pub sample_weights: Option<Vec<f64>>,
    /// Task metadata
    pub metadata: HashMap<String, String>,
}
204
/// Results from multi-task training.
#[derive(Debug, Clone)]
pub struct MultiTaskTrainingResults {
    /// Training metrics per task, keyed by task id
    pub task_metrics: HashMap<String, TaskMetrics>,
    /// Cross-task transfer effects, keyed by (source, target) task ids
    pub transfer_effects: HashMap<(String, String), f64>,
    /// Final task similarities, keyed by (task_a, task_b) in both orders
    pub final_similarities: HashMap<(String, String), f64>,
    /// Convergence information
    pub convergence_info: ConvergenceInfo,
}
217
/// Metrics for individual tasks.
#[derive(Debug, Clone)]
pub struct TaskMetrics {
    /// Training accuracy/error (MSE for the regressor in this module)
    pub training_score: f64,
    /// Validation accuracy/error
    // NOTE(review): currently set equal to training_score by the train_* methods.
    pub validation_score: f64,
    /// Number of training samples
    pub n_samples: usize,
    /// Training time
    // NOTE(review): currently always 0.0 — timing is not measured yet.
    pub training_time: f64,
    /// Model complexity measure (number of models in the task's ensemble)
    pub complexity: f64,
}
232
/// Convergence information for multi-task training.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Number of iterations to convergence
    pub n_iterations: usize,
    /// Final loss value
    pub final_loss: f64,
    /// Convergence tolerance achieved
    pub tolerance_achieved: f64,
    /// Whether convergence was reached
    pub converged: bool,
}
245
impl MultiTaskEnsembleConfig {
    /// Create a builder for fluent construction of a configuration.
    pub fn builder() -> MultiTaskEnsembleConfigBuilder {
        MultiTaskEnsembleConfigBuilder::default()
    }
}
251
/// Fluent builder for [`MultiTaskEnsembleConfig`]; starts from the
/// configuration's `Default` values.
#[derive(Default)]
pub struct MultiTaskEnsembleConfigBuilder {
    config: MultiTaskEnsembleConfig,
}
256
257impl MultiTaskEnsembleConfigBuilder {
258    pub fn n_estimators_per_task(mut self, n_estimators: usize) -> Self {
259        self.config.n_estimators_per_task = n_estimators;
260        self
261    }
262
263    pub fn sharing_strategy(mut self, strategy: TaskSharingStrategy) -> Self {
264        self.config.sharing_strategy = strategy;
265        self
266    }
267
268    pub fn similarity_metric(mut self, metric: TaskSimilarityMetric) -> Self {
269        self.config.similarity_metric = metric;
270        self
271    }
272
273    pub fn min_similarity_threshold(mut self, threshold: f64) -> Self {
274        self.config.min_similarity_threshold = threshold;
275        self
276    }
277
278    pub fn task_weighting(mut self, weighting: TaskWeightingStrategy) -> Self {
279        self.config.task_weighting = weighting;
280        self
281    }
282
283    pub fn use_task_specific_features(mut self, use_specific: bool) -> Self {
284        self.config.use_task_specific_features = use_specific;
285        self
286    }
287
288    pub fn sharing_regularization(mut self, regularization: f64) -> Self {
289        self.config.sharing_regularization = regularization;
290        self
291    }
292
293    pub fn cross_task_validation(mut self, validation: CrossTaskValidation) -> Self {
294        self.config.cross_task_validation = validation;
295        self
296    }
297
298    pub fn build(self) -> MultiTaskEnsembleConfig {
299        self.config
300    }
301}
302
303impl MultiTaskEnsembleRegressor {
304    pub fn new(config: MultiTaskEnsembleConfig) -> Self {
305        Self {
306            config,
307            state: std::marker::PhantomData,
308            task_models: None,
309            shared_models: None,
310            task_similarities: None,
311            task_weights: None,
312            feature_selector: None,
313            task_hierarchy: None,
314        }
315    }
316
    /// Create a builder for fluent construction of this regressor.
    pub fn builder() -> MultiTaskEnsembleRegressorBuilder {
        MultiTaskEnsembleRegressorBuilder::new()
    }
320
321    /// Fit the multi-task ensemble on multiple tasks
322    pub fn fit_tasks(
323        mut self,
324        tasks: &[TaskData],
325    ) -> SklResult<MultiTaskEnsembleRegressor<Trained>> {
326        if tasks.is_empty() {
327            return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
328        }
329
330        // Initialize task weights
331        self.initialize_task_weights(tasks)?;
332
333        // Build task hierarchy if needed
334        if matches!(
335            self.config.sharing_strategy,
336            TaskSharingStrategy::HierarchicalSharing
337        ) {
338            self.build_task_hierarchy(tasks)?;
339        }
340
341        // Initialize feature selector
342        if self.config.use_task_specific_features {
343            self.initialize_feature_selector(tasks)?;
344        }
345
346        // Compute task similarities
347        self.compute_task_similarities(tasks)?;
348
349        // Train models based on sharing strategy
350        let training_results = match self.config.sharing_strategy {
351            TaskSharingStrategy::Independent => self.train_independent_tasks(tasks)?,
352            TaskSharingStrategy::SharedRepresentation => self.train_shared_representation(tasks)?,
353            TaskSharingStrategy::ParameterSharing => self.train_parameter_sharing(tasks)?,
354            TaskSharingStrategy::HierarchicalSharing => self.train_hierarchical_sharing(tasks)?,
355            TaskSharingStrategy::AdaptiveSharing => self.train_adaptive_sharing(tasks)?,
356            _ => self.train_independent_tasks(tasks)?, // Default fallback
357        };
358
359        let fitted_ensemble = MultiTaskEnsembleRegressor::<Trained> {
360            config: self.config,
361            state: std::marker::PhantomData,
362            task_models: self.task_models,
363            shared_models: self.shared_models,
364            task_similarities: self.task_similarities,
365            task_weights: self.task_weights,
366            feature_selector: self.feature_selector,
367            task_hierarchy: self.task_hierarchy,
368        };
369
370        Ok(fitted_ensemble)
371    }
372
373    /// Initialize task weights based on configuration
374    fn initialize_task_weights(&mut self, tasks: &[TaskData]) -> SklResult<()> {
375        let mut weights = HashMap::new();
376
377        match self.config.task_weighting {
378            TaskWeightingStrategy::Uniform => {
379                let weight = 1.0 / tasks.len() as f64;
380                for task in tasks {
381                    weights.insert(task.task_id.clone(), weight);
382                }
383            }
384            TaskWeightingStrategy::SampleSizeBased => {
385                let total_samples: usize = tasks.iter().map(|t| t.features.shape()[0]).sum();
386                for task in tasks {
387                    let weight = task.features.shape()[0] as f64 / total_samples as f64;
388                    weights.insert(task.task_id.clone(), weight);
389                }
390            }
391            _ => {
392                // Default to uniform for now
393                let weight = 1.0 / tasks.len() as f64;
394                for task in tasks {
395                    weights.insert(task.task_id.clone(), weight);
396                }
397            }
398        }
399
400        self.task_weights = Some(weights);
401        Ok(())
402    }
403
404    /// Build task hierarchy for hierarchical sharing
405    fn build_task_hierarchy(&mut self, tasks: &[TaskData]) -> SklResult<()> {
406        let mut hierarchy = HashMap::new();
407        let mut task_depths = HashMap::new();
408        let mut hierarchy_weights = HashMap::new();
409
410        // Simple hierarchical construction based on task similarities
411        // In practice, this could be based on domain knowledge or learned
412        for (i, task1) in tasks.iter().enumerate() {
413            task_depths.insert(task1.task_id.clone(), 0);
414            hierarchy.insert(task1.task_id.clone(), Vec::new());
415
416            for (j, task2) in tasks.iter().enumerate() {
417                if i != j {
418                    // Set hierarchical weight based on some similarity measure
419                    let weight = 1.0 / (i.abs_diff(j) + 1) as f64;
420                    hierarchy_weights
421                        .insert((task1.task_id.clone(), task2.task_id.clone()), weight);
422                }
423            }
424        }
425
426        self.task_hierarchy = Some(TaskHierarchy {
427            hierarchy,
428            task_depths,
429            hierarchy_weights,
430        });
431
432        Ok(())
433    }
434
435    /// Initialize feature selector for task-specific features
436    fn initialize_feature_selector(&mut self, tasks: &[TaskData]) -> SklResult<()> {
437        if tasks.is_empty() {
438            return Ok(());
439        }
440
441        let n_features = tasks[0].features.shape()[1];
442        let mut task_feature_masks = HashMap::new();
443        let mut task_feature_importances = HashMap::new();
444
445        // For now, use simple feature selection based on variance
446        for task in tasks {
447            let mut feature_mask = vec![true; n_features];
448            let mut feature_importances = vec![0.0; n_features];
449
450            // Calculate feature variances as a simple importance measure
451            for j in 0..n_features {
452                let column: Vec<f64> = (0..task.features.shape()[0])
453                    .map(|i| task.features[[i, j]])
454                    .collect();
455                let mean = column.iter().sum::<f64>() / column.len() as f64;
456                let variance =
457                    column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / column.len() as f64;
458
459                feature_importances[j] = variance;
460                feature_mask[j] = variance > 1e-8; // Keep features with non-zero variance
461            }
462
463            task_feature_masks.insert(task.task_id.clone(), feature_mask);
464            task_feature_importances.insert(task.task_id.clone(), feature_importances);
465        }
466
467        // Global feature importance (average across tasks)
468        let mut global_feature_importances = vec![0.0; n_features];
469        for j in 0..n_features {
470            let sum: f64 = task_feature_importances
471                .values()
472                .map(|importances| importances[j])
473                .sum();
474            global_feature_importances[j] = sum / tasks.len() as f64;
475        }
476
477        // Shared feature mask (features important across multiple tasks)
478        let shared_feature_mask: Vec<bool> = (0..n_features)
479            .map(|j| {
480                let important_count = task_feature_importances
481                    .values()
482                    .filter(|importances| importances[j] > global_feature_importances[j] * 0.5)
483                    .count();
484                important_count >= tasks.len() / 2
485            })
486            .collect();
487
488        self.feature_selector = Some(MultiTaskFeatureSelector {
489            task_feature_masks,
490            shared_feature_mask,
491            task_feature_importances,
492            global_feature_importances,
493        });
494
495        Ok(())
496    }
497
498    /// Compute similarities between tasks
499    fn compute_task_similarities(&mut self, tasks: &[TaskData]) -> SklResult<()> {
500        let mut similarities = HashMap::new();
501
502        for (i, task1) in tasks.iter().enumerate() {
503            for task2 in tasks.iter().skip(i + 1) {
504                let similarity = self.calculate_task_similarity(task1, task2)?;
505                similarities.insert((task1.task_id.clone(), task2.task_id.clone()), similarity);
506                similarities.insert((task2.task_id.clone(), task1.task_id.clone()), similarity);
507            }
508        }
509
510        self.task_similarities = Some(similarities);
511        Ok(())
512    }
513
514    /// Calculate similarity between two tasks
515    fn calculate_task_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
516        match self.config.similarity_metric {
517            TaskSimilarityMetric::CorrelationBased => self.correlation_similarity(task1, task2),
518            TaskSimilarityMetric::DistributionSimilarity => {
519                self.distribution_similarity(task1, task2)
520            }
521            _ => {
522                // Default to correlation-based similarity
523                self.correlation_similarity(task1, task2)
524            }
525        }
526    }
527
528    /// Calculate correlation-based similarity between tasks
529    fn correlation_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
530        // Simple correlation between target variables if they have the same length
531        if task1.labels.len() != task2.labels.len() {
532            return Ok(0.0); // No similarity if different lengths
533        }
534
535        let n = task1.labels.len();
536        if n < 2 {
537            return Ok(0.0);
538        }
539
540        let mean1 = task1.labels.iter().sum::<f64>() / n as f64;
541        let mean2 = task2.labels.iter().sum::<f64>() / n as f64;
542
543        let mut numerator = 0.0;
544        let mut denom1 = 0.0;
545        let mut denom2 = 0.0;
546
547        for i in 0..n {
548            let diff1 = task1.labels[i] - mean1;
549            let diff2 = task2.labels[i] - mean2;
550            numerator += diff1 * diff2;
551            denom1 += diff1 * diff1;
552            denom2 += diff2 * diff2;
553        }
554
555        if denom1 * denom2 > 0.0 {
556            Ok(numerator / (denom1 * denom2).sqrt())
557        } else {
558            Ok(0.0)
559        }
560    }
561
562    /// Calculate distribution similarity between tasks
563    fn distribution_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
564        // Simple approach: compare feature means and variances
565        let n_features1 = task1.features.shape()[1];
566        let n_features2 = task2.features.shape()[1];
567
568        if n_features1 != n_features2 {
569            return Ok(0.0);
570        }
571
572        let mut similarity_sum = 0.0;
573
574        for j in 0..n_features1 {
575            let col1: Vec<f64> = (0..task1.features.shape()[0])
576                .map(|i| task1.features[[i, j]])
577                .collect();
578            let col2: Vec<f64> = (0..task2.features.shape()[0])
579                .map(|i| task2.features[[i, j]])
580                .collect();
581
582            let mean1 = col1.iter().sum::<f64>() / col1.len() as f64;
583            let mean2 = col2.iter().sum::<f64>() / col2.len() as f64;
584
585            let var1 = col1.iter().map(|&x| (x - mean1).powi(2)).sum::<f64>() / col1.len() as f64;
586            let var2 = col2.iter().map(|&x| (x - mean2).powi(2)).sum::<f64>() / col2.len() as f64;
587
588            // Similarity based on mean and variance differences
589            let mean_sim = 1.0 - (mean1 - mean2).abs() / (mean1.abs() + mean2.abs() + 1e-8);
590            let var_sim = 1.0 - (var1 - var2).abs() / (var1 + var2 + 1e-8);
591
592            similarity_sum += (mean_sim + var_sim) / 2.0;
593        }
594
595        Ok(similarity_sum / n_features1 as f64)
596    }
597
598    /// Train independent models for each task
599    fn train_independent_tasks(
600        &mut self,
601        tasks: &[TaskData],
602    ) -> SklResult<MultiTaskTrainingResults> {
603        let mut task_metrics = HashMap::new();
604
605        // Initialize task_models
606        if self.task_models.is_none() {
607            self.task_models = Some(HashMap::new());
608        }
609
610        for task in tasks {
611            let mut models = Vec::new();
612
613            for _ in 0..self.config.n_estimators_per_task {
614                let gb_config = GradientBoostingConfig {
615                    n_estimators: 50,
616                    learning_rate: 0.1,
617                    max_depth: 6,
618                    ..Default::default()
619                };
620
621                let y_array = Array1::from_vec(task.labels.clone());
622                let model = GradientBoostingRegressor::builder()
623                    .n_estimators(50)
624                    .learning_rate(0.1)
625                    .max_depth(6)
626                    .build()
627                    .fit(&task.features, &y_array)?;
628
629                models.push(model);
630            }
631
632            // Calculate task metrics
633            let predictions = self.predict_task_ensemble(&models, &task.features)?;
634            let mse = self.calculate_mse(&predictions, &task.labels);
635
636            task_metrics.insert(
637                task.task_id.clone(),
638                TaskMetrics {
639                    training_score: mse,
640                    validation_score: mse, // Would be different in practice
641                    n_samples: task.features.shape()[0],
642                    training_time: 0.0, // Would measure actual time
643                    complexity: models.len() as f64,
644                },
645            );
646
647            self.task_models
648                .as_mut()
649                .unwrap()
650                .insert(task.task_id.clone(), models);
651        }
652
653        Ok(MultiTaskTrainingResults {
654            task_metrics,
655            transfer_effects: HashMap::new(),
656            final_similarities: self.task_similarities.clone().unwrap_or_default(),
657            convergence_info: ConvergenceInfo {
658                n_iterations: 1,
659                final_loss: 0.0,
660                tolerance_achieved: 0.0,
661                converged: true,
662            },
663        })
664    }
665
666    /// Train with shared representation learning
667    fn train_shared_representation(
668        &mut self,
669        tasks: &[TaskData],
670    ) -> SklResult<MultiTaskTrainingResults> {
671        // First train shared models on combined data
672        let combined_data = self.combine_task_data(tasks)?;
673
674        // Initialize shared_models
675        if self.shared_models.is_none() {
676            self.shared_models = Some(Vec::new());
677        }
678
679        for _ in 0..self.config.n_estimators_per_task {
680            let gb_config = GradientBoostingConfig {
681                n_estimators: 30,
682                learning_rate: 0.1,
683                max_depth: 4,
684                ..Default::default()
685            };
686
687            let y_array = Array1::from_vec(combined_data.labels.clone());
688            let shared_model = GradientBoostingRegressor::builder()
689                .n_estimators(30)
690                .learning_rate(0.1)
691                .max_depth(4)
692                .build()
693                .fit(&combined_data.features, &y_array)?;
694
695            self.shared_models.as_mut().unwrap().push(shared_model);
696        }
697
698        // Then train task-specific models
699        self.train_independent_tasks(tasks)
700    }
701
702    /// Train with parameter sharing between similar tasks
703    fn train_parameter_sharing(
704        &mut self,
705        tasks: &[TaskData],
706    ) -> SklResult<MultiTaskTrainingResults> {
707        // Group similar tasks
708        let task_groups = self.group_similar_tasks(tasks)?;
709
710        let mut task_metrics = HashMap::new();
711
712        for group in task_groups {
713            // Train shared models for this group
714            let group_data = self.combine_group_data(tasks, &group)?;
715
716            let mut group_models = Vec::new();
717            for _ in 0..self.config.n_estimators_per_task {
718                let gb_config = GradientBoostingConfig {
719                    n_estimators: 40,
720                    learning_rate: 0.1,
721                    max_depth: 5,
722                    ..Default::default()
723                };
724
725                let y_array = Array1::from_vec(group_data.labels.clone());
726                let model = GradientBoostingRegressor::builder()
727                    .n_estimators(40)
728                    .learning_rate(0.1)
729                    .max_depth(5)
730                    .build()
731                    .fit(&group_data.features, &y_array)?;
732
733                group_models.push(model);
734            }
735
736            // Calculate metrics for tasks in this group and create separate model sets
737            for task_id in &group {
738                let task = tasks.iter().find(|t| &t.task_id == task_id).unwrap();
739
740                // Calculate metrics for this task
741                let predictions = self.predict_task_ensemble(&group_models, &task.features)?;
742                let mse = self.calculate_mse(&predictions, &task.labels);
743
744                task_metrics.insert(
745                    task_id.clone(),
746                    TaskMetrics {
747                        training_score: mse,
748                        validation_score: mse,
749                        n_samples: task.features.shape()[0],
750                        training_time: 0.0,
751                        complexity: group_models.len() as f64,
752                    },
753                );
754
755                // For each task, train a separate set of models with the same configuration
756                let mut task_models = Vec::new();
757                for _ in 0..self.config.n_estimators_per_task {
758                    let y_array = Array1::from_vec(group_data.labels.clone());
759                    let model = GradientBoostingRegressor::builder()
760                        .n_estimators(40)
761                        .learning_rate(0.1)
762                        .max_depth(5)
763                        .build()
764                        .fit(&group_data.features, &y_array)?;
765                    task_models.push(model);
766                }
767
768                self.task_models
769                    .as_mut()
770                    .unwrap()
771                    .insert(task_id.clone(), task_models);
772            }
773        }
774
775        Ok(MultiTaskTrainingResults {
776            task_metrics,
777            transfer_effects: HashMap::new(),
778            final_similarities: self.task_similarities.clone().unwrap_or_default(),
779            convergence_info: ConvergenceInfo {
780                n_iterations: 1,
781                final_loss: 0.0,
782                tolerance_achieved: 0.0,
783                converged: true,
784            },
785        })
786    }
787
    /// Train with hierarchical sharing.
    ///
    /// Placeholder: currently delegates to parameter sharing. A full
    /// implementation would exploit the parent/child relationships and
    /// weights stored in `self.task_hierarchy` (built by
    /// `build_task_hierarchy` before this is called).
    fn train_hierarchical_sharing(
        &mut self,
        tasks: &[TaskData],
    ) -> SklResult<MultiTaskTrainingResults> {
        // For now, delegate to parameter sharing
        // In practice, this would implement hierarchical relationships
        self.train_parameter_sharing(tasks)
    }
797
798    /// Train with adaptive sharing based on task similarities
799    fn train_adaptive_sharing(
800        &mut self,
801        tasks: &[TaskData],
802    ) -> SklResult<MultiTaskTrainingResults> {
803        let mut task_metrics = HashMap::new();
804
805        for task in tasks {
806            let mut models = Vec::new();
807
808            // Find similar tasks for this task
809            let similar_tasks = self.find_similar_tasks(&task.task_id, tasks);
810
811            if similar_tasks.len() > 1 {
812                // Train with data from similar tasks
813                let combined_data = self.combine_similar_task_data(tasks, &similar_tasks)?;
814
815                for _ in 0..self.config.n_estimators_per_task {
816                    let gb_config = GradientBoostingConfig {
817                        n_estimators: 50,
818                        learning_rate: 0.1,
819                        max_depth: 6,
820                        ..Default::default()
821                    };
822
823                    let y_array = Array1::from_vec(combined_data.labels.clone());
824                    let model = GradientBoostingRegressor::builder()
825                        .n_estimators(50)
826                        .learning_rate(0.1)
827                        .max_depth(6)
828                        .build()
829                        .fit(&combined_data.features, &y_array)?;
830
831                    models.push(model);
832                }
833            } else {
834                // Train independently if no similar tasks
835                for _ in 0..self.config.n_estimators_per_task {
836                    let gb_config = GradientBoostingConfig {
837                        n_estimators: 50,
838                        learning_rate: 0.1,
839                        max_depth: 6,
840                        ..Default::default()
841                    };
842
843                    let y_array = Array1::from_vec(task.labels.clone());
844                    let model = GradientBoostingRegressor::builder()
845                        .n_estimators(50)
846                        .learning_rate(0.1)
847                        .max_depth(6)
848                        .build()
849                        .fit(&task.features, &y_array)?;
850
851                    models.push(model);
852                }
853            }
854
855            // Calculate metrics
856            let predictions = self.predict_task_ensemble(&models, &task.features)?;
857            let mse = self.calculate_mse(&predictions, &task.labels);
858
859            task_metrics.insert(
860                task.task_id.clone(),
861                TaskMetrics {
862                    training_score: mse,
863                    validation_score: mse,
864                    n_samples: task.features.shape()[0],
865                    training_time: 0.0,
866                    complexity: models.len() as f64,
867                },
868            );
869
870            self.task_models
871                .as_mut()
872                .unwrap()
873                .insert(task.task_id.clone(), models);
874        }
875
876        Ok(MultiTaskTrainingResults {
877            task_metrics,
878            transfer_effects: HashMap::new(),
879            final_similarities: self.task_similarities.clone().unwrap_or_default(),
880            convergence_info: ConvergenceInfo {
881                n_iterations: 1,
882                final_loss: 0.0,
883                tolerance_achieved: 0.0,
884                converged: true,
885            },
886        })
887    }
888
889    /// Combine data from multiple tasks
890    fn combine_task_data(&self, tasks: &[TaskData]) -> SklResult<TaskData> {
891        if tasks.is_empty() {
892            return Err(SklearsError::InvalidInput(
893                "No tasks to combine".to_string(),
894            ));
895        }
896
897        let total_samples: usize = tasks.iter().map(|t| t.features.shape()[0]).sum();
898        let n_features = tasks[0].features.shape()[1];
899
900        let mut combined_features = Vec::with_capacity(total_samples * n_features);
901        let mut combined_labels = Vec::with_capacity(total_samples);
902
903        for task in tasks {
904            for i in 0..task.features.shape()[0] {
905                for j in 0..n_features {
906                    combined_features.push(task.features[[i, j]]);
907                }
908                combined_labels.push(task.labels[i]);
909            }
910        }
911
912        let features = Array2::from_shape_vec((total_samples, n_features), combined_features)?;
913
914        Ok(TaskData {
915            task_id: "combined".to_string(),
916            features,
917            labels: combined_labels,
918            sample_weights: None,
919            metadata: HashMap::new(),
920        })
921    }
922
923    /// Group similar tasks together
924    fn group_similar_tasks(&self, tasks: &[TaskData]) -> SklResult<Vec<Vec<String>>> {
925        let mut groups = Vec::new();
926        let mut assigned = vec![false; tasks.len()];
927
928        for (i, task) in tasks.iter().enumerate() {
929            if assigned[i] {
930                continue;
931            }
932
933            let mut group = vec![task.task_id.clone()];
934            assigned[i] = true;
935
936            // Find similar tasks
937            for (j, other_task) in tasks.iter().enumerate() {
938                if i != j && !assigned[j] {
939                    let similarity = self
940                        .task_similarities
941                        .as_ref()
942                        .and_then(|similarities| {
943                            similarities.get(&(task.task_id.clone(), other_task.task_id.clone()))
944                        })
945                        .copied()
946                        .unwrap_or(0.0);
947
948                    if similarity >= self.config.min_similarity_threshold {
949                        group.push(other_task.task_id.clone());
950                        assigned[j] = true;
951                    }
952                }
953            }
954
955            groups.push(group);
956        }
957
958        Ok(groups)
959    }
960
961    /// Combine data from a group of tasks
962    fn combine_group_data(&self, tasks: &[TaskData], group: &[String]) -> SklResult<TaskData> {
963        let group_tasks: Vec<TaskData> = tasks
964            .iter()
965            .filter(|t| group.contains(&t.task_id))
966            .cloned()
967            .collect();
968
969        self.combine_task_data(&group_tasks)
970    }
971
972    /// Find tasks similar to a given task
973    fn find_similar_tasks(&self, task_id: &str, tasks: &[TaskData]) -> Vec<String> {
974        let mut similar_tasks = vec![task_id.to_string()];
975
976        for task in tasks {
977            if task.task_id != task_id {
978                let similarity = self
979                    .task_similarities
980                    .as_ref()
981                    .and_then(|similarities| {
982                        similarities.get(&(task_id.to_string(), task.task_id.clone()))
983                    })
984                    .copied()
985                    .unwrap_or(0.0);
986
987                if similarity >= self.config.min_similarity_threshold {
988                    similar_tasks.push(task.task_id.clone());
989                }
990            }
991        }
992
993        similar_tasks
994    }
995
996    /// Combine data from similar tasks
997    fn combine_similar_task_data(
998        &self,
999        tasks: &[TaskData],
1000        similar_task_ids: &[String],
1001    ) -> SklResult<TaskData> {
1002        let similar_tasks: Vec<TaskData> = tasks
1003            .iter()
1004            .filter(|t| similar_task_ids.contains(&t.task_id))
1005            .cloned()
1006            .collect();
1007
1008        self.combine_task_data(&similar_tasks)
1009    }
1010
1011    /// Predict using task ensemble
1012    fn predict_task_ensemble(
1013        &self,
1014        models: &[TrainedGradientBoostingRegressor],
1015        X: &Array2<f64>,
1016    ) -> SklResult<Vec<f64>> {
1017        if models.is_empty() {
1018            return Err(SklearsError::InvalidInput(
1019                "No models in ensemble".to_string(),
1020            ));
1021        }
1022
1023        let mut predictions = vec![0.0; X.shape()[0]];
1024
1025        for model in models {
1026            let pred = model.predict(X)?;
1027            for (i, &p) in pred.iter().enumerate() {
1028                predictions[i] += p;
1029            }
1030        }
1031
1032        // Average predictions
1033        for p in &mut predictions {
1034            *p /= models.len() as f64;
1035        }
1036
1037        Ok(predictions)
1038    }
1039
1040    /// Calculate mean squared error
1041    fn calculate_mse(&self, predictions: &[f64], targets: &[f64]) -> f64 {
1042        if predictions.len() != targets.len() {
1043            return f64::INFINITY;
1044        }
1045
1046        let sum_squared_error: f64 = predictions
1047            .iter()
1048            .zip(targets.iter())
1049            .map(|(&p, &t)| (p - t).powi(2))
1050            .sum();
1051
1052        sum_squared_error / predictions.len() as f64
1053    }
1054}
1055
1056impl MultiTaskEnsembleRegressor<Trained> {
1057    /// Predict for a specific task
1058    pub fn predict_task(&self, task_id: &str, X: &Array2<f64>) -> SklResult<Vec<f64>> {
1059        // Use task-specific models if available
1060        if let Some(models) = self.task_models.as_ref().and_then(|m| m.get(task_id)) {
1061            let mut predictions = self.predict_task_ensemble(models, X)?;
1062
1063            // Add shared model predictions if available
1064            if let Some(shared_models) = self.shared_models.as_ref() {
1065                if !shared_models.is_empty() {
1066                    let shared_predictions = self.predict_task_ensemble(shared_models, X)?;
1067                    for (i, &shared_pred) in shared_predictions.iter().enumerate() {
1068                        predictions[i] = 0.7 * predictions[i] + 0.3 * shared_pred;
1069                    }
1070                }
1071            }
1072
1073            Ok(predictions)
1074        } else {
1075            Err(SklearsError::InvalidInput(format!(
1076                "Task '{}' not found",
1077                task_id
1078            )))
1079        }
1080    }
1081
1082    /// Predict using task ensemble (internal method)
1083    fn predict_task_ensemble(
1084        &self,
1085        models: &[TrainedGradientBoostingRegressor],
1086        X: &Array2<f64>,
1087    ) -> SklResult<Vec<f64>> {
1088        if models.is_empty() {
1089            return Err(SklearsError::InvalidInput(
1090                "No models in ensemble".to_string(),
1091            ));
1092        }
1093
1094        let mut predictions = vec![0.0; X.shape()[0]];
1095
1096        for model in models {
1097            let pred = model.predict(X)?;
1098            for (i, &p) in pred.iter().enumerate() {
1099                predictions[i] += p;
1100            }
1101        }
1102
1103        // Average predictions
1104        for p in &mut predictions {
1105            *p /= models.len() as f64;
1106        }
1107
1108        Ok(predictions)
1109    }
1110
1111    /// Get task similarities
1112    pub fn get_task_similarities(&self) -> &HashMap<(String, String), f64> {
1113        self.task_similarities.as_ref().expect("Model is trained")
1114    }
1115
1116    /// Get task weights
1117    pub fn get_task_weights(&self) -> &HashMap<String, f64> {
1118        self.task_weights.as_ref().expect("Model is trained")
1119    }
1120}
1121
/// Builder for configuring and constructing a [`MultiTaskEnsembleRegressor`].
pub struct MultiTaskEnsembleRegressorBuilder {
    // Configuration accumulated by the builder's setter methods.
    config: MultiTaskEnsembleConfig,
}
1125
1126impl Default for MultiTaskEnsembleRegressorBuilder {
1127    fn default() -> Self {
1128        Self::new()
1129    }
1130}
1131
1132impl MultiTaskEnsembleRegressorBuilder {
1133    pub fn new() -> Self {
1134        Self {
1135            config: MultiTaskEnsembleConfig::default(),
1136        }
1137    }
1138
1139    pub fn config(mut self, config: MultiTaskEnsembleConfig) -> Self {
1140        self.config = config;
1141        self
1142    }
1143
1144    pub fn n_estimators_per_task(mut self, n_estimators: usize) -> Self {
1145        self.config.n_estimators_per_task = n_estimators;
1146        self
1147    }
1148
1149    pub fn sharing_strategy(mut self, strategy: TaskSharingStrategy) -> Self {
1150        self.config.sharing_strategy = strategy;
1151        self
1152    }
1153
1154    pub fn min_similarity_threshold(mut self, threshold: f64) -> Self {
1155        self.config.min_similarity_threshold = threshold;
1156        self
1157    }
1158
1159    pub fn build(self) -> MultiTaskEnsembleRegressor {
1160        MultiTaskEnsembleRegressor::new(self.config)
1161    }
1162}
1163
// Wire the regressor into the crate-wide `Estimator` abstraction so generic
// code can inspect its configuration and error/float types.
impl Estimator for MultiTaskEnsembleRegressor {
    type Config = MultiTaskEnsembleConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Borrow the estimator's configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
1173
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// The builder should faithfully record every configured field.
    #[test]
    fn test_multi_task_config() {
        let config = MultiTaskEnsembleConfig::builder()
            .n_estimators_per_task(5)
            .sharing_strategy(TaskSharingStrategy::SharedRepresentation)
            .min_similarity_threshold(0.5)
            .build();

        assert_eq!(config.n_estimators_per_task, 5);
        assert_eq!(
            config.sharing_strategy,
            TaskSharingStrategy::SharedRepresentation
        );
        assert_eq!(config.min_similarity_threshold, 0.5);
    }

    /// `TaskData` can be constructed directly and exposes its fields.
    #[test]
    fn test_task_data_creation() {
        let feature_matrix =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let task = TaskData {
            task_id: "test_task".to_string(),
            features: feature_matrix,
            labels: vec![1.0, 2.0, 3.0],
            sample_weights: None,
            metadata: HashMap::new(),
        };

        assert_eq!(task.task_id, "test_task");
        assert_eq!(task.features.shape(), &[3, 2]);
        assert_eq!(task.labels.len(), 3);
    }

    /// A freshly built ensemble carries its config and holds no models yet.
    #[test]
    fn test_multi_task_ensemble_basic() {
        let config = MultiTaskEnsembleConfig::builder()
            .n_estimators_per_task(2)
            .sharing_strategy(TaskSharingStrategy::Independent)
            .build();

        let ensemble = MultiTaskEnsembleRegressor::new(config);

        assert_eq!(ensemble.config.n_estimators_per_task, 2);
        assert_eq!(
            ensemble.config.sharing_strategy,
            TaskSharingStrategy::Independent
        );
        // Untrained state: no per-task models exist yet.
        assert!(ensemble.task_models.is_none());
    }

    /// Identical label vectors must yield a correlation similarity of 1.
    #[test]
    fn test_task_similarity_calculation() {
        let ensemble = MultiTaskEnsembleRegressor::new(MultiTaskEnsembleConfig::default());

        // Both tasks share the same features and labels.
        let make_task = |id: &str| TaskData {
            task_id: id.to_string(),
            features: Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
            labels: vec![1.0, 2.0],
            sample_weights: None,
            metadata: HashMap::new(),
        };

        let task1 = make_task("task1");
        let task2 = make_task("task2");

        let similarity = ensemble.correlation_similarity(&task1, &task2).unwrap();
        assert!((similarity - 1.0).abs() < 1e-10); // perfectly correlated
    }
}