// sklears_ensemble/multi_task.rs
1//! Multi-Task Ensemble Methods
2//!
3//! This module provides ensemble methods for multi-task learning, where multiple
4//! related learning tasks are solved jointly to improve generalization performance
5//! by leveraging information shared across tasks.
6
7use crate::bagging::BaggingClassifier;
8use crate::gradient_boosting::{
9    GradientBoostingConfig, GradientBoostingRegressor, TrainedGradientBoostingRegressor,
10};
11use scirs2_core::ndarray::{Array1, Array2};
12use sklears_core::{
13    error::Result as SklResult,
14    prelude::{Predict, SklearsError},
15    traits::{Estimator, Fit, Trained, Untrained},
16};
17use std::collections::HashMap;
18
/// A trait that combines the `Estimator` and `Predict` traits.
///
/// Convenience bound for base learners used in multi-task ensembles: any type
/// that is both an `Estimator` (with matching config/error/float associated
/// types) and a `Predict<X, Y>` satisfies it automatically via the blanket
/// impl below — no manual opt-in is required.
pub trait MultiTaskEstimator<C, E, F, X, Y>:
    Estimator<Config = C, Error = E, Float = F> + Predict<X, Y>
{
}

/// Blanket implementation: every predicting estimator is a `MultiTaskEstimator`.
impl<T, C, E, F, X, Y> MultiTaskEstimator<C, E, F, X, Y> for T where
    T: Estimator<Config = C, Error = E, Float = F> + Predict<X, Y>
{
}
29
/// Configuration for multi-task ensemble learning.
///
/// Construct via [`MultiTaskEnsembleConfig::builder`] or rely on `Default`.
/// NOTE(review): some fields (`n_shared_features`, `max_task_depth`,
/// `sharing_regularization`, `cross_task_validation`) are not consulted by the
/// training paths visible in this module chunk — confirm they are used
/// elsewhere before relying on them.
#[derive(Debug, Clone)]
pub struct MultiTaskEnsembleConfig {
    /// Number of base estimators per task
    pub n_estimators_per_task: usize,
    /// Task sharing strategy
    pub sharing_strategy: TaskSharingStrategy,
    /// Task similarity metric for adaptive sharing
    pub similarity_metric: TaskSimilarityMetric,
    /// Minimum task similarity threshold for sharing
    pub min_similarity_threshold: f64,
    /// Task weighting strategy
    pub task_weighting: TaskWeightingStrategy,
    /// Whether to use task-specific feature selection
    pub use_task_specific_features: bool,
    /// Number of shared features across tasks
    pub n_shared_features: Option<usize>,
    /// Regularization strength for task sharing
    pub sharing_regularization: f64,
    /// Maximum depth for task hierarchy
    pub max_task_depth: usize,
    /// Cross-task validation strategy
    pub cross_task_validation: CrossTaskValidation,
}
54
55impl Default for MultiTaskEnsembleConfig {
56    fn default() -> Self {
57        Self {
58            n_estimators_per_task: 10,
59            sharing_strategy: TaskSharingStrategy::SharedRepresentation,
60            similarity_metric: TaskSimilarityMetric::CorrelationBased,
61            min_similarity_threshold: 0.3,
62            task_weighting: TaskWeightingStrategy::Uniform,
63            use_task_specific_features: true,
64            n_shared_features: None,
65            sharing_regularization: 0.1,
66            max_task_depth: 3,
67            cross_task_validation: CrossTaskValidation::LeaveOneTaskOut,
68        }
69    }
70}
71
/// Strategies for sharing information between tasks.
///
/// NOTE(review): in the current `MultiTaskEnsembleRegressor::fit_tasks`
/// dispatch, `MultiLevelSharing` and `TransferLearning` fall back to
/// independent training.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskSharingStrategy {
    /// No sharing between tasks
    Independent,
    /// Shared representation learning
    SharedRepresentation,
    /// Parameter sharing between tasks
    ParameterSharing,
    /// Hierarchical task relationships
    HierarchicalSharing,
    /// Adaptive sharing based on task similarity
    AdaptiveSharing,
    /// Multi-level sharing with different granularities
    MultiLevelSharing,
    /// Transfer learning between tasks
    TransferLearning,
}
90
/// Metrics for measuring task similarity.
///
/// NOTE(review): the regressor currently implements only `CorrelationBased`
/// and `DistributionSimilarity`; all other variants fall back to
/// correlation-based similarity.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskSimilarityMetric {
    /// Correlation-based similarity
    CorrelationBased,
    /// Feature importance similarity
    FeatureImportanceSimilarity,
    /// Model prediction similarity
    PredictionSimilarity,
    /// Data distribution similarity
    DistributionSimilarity,
    /// Gradient similarity
    GradientSimilarity,
    /// Task performance correlation
    PerformanceCorrelation,
}
107
/// Strategies for weighting different tasks.
///
/// NOTE(review): only `Uniform` and `SampleSizeBased` are currently
/// implemented in `initialize_task_weights`; other variants fall back to
/// uniform weighting.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskWeightingStrategy {
    /// Equal weight for all tasks
    Uniform,
    /// Weight based on task difficulty
    DifficultyBased,
    /// Weight based on task sample size
    SampleSizeBased,
    /// Weight based on task performance
    PerformanceBased,
    /// Adaptive weighting during training
    AdaptiveWeighting,
    /// Weight based on task importance
    ImportanceBased,
}
124
/// Cross-task validation strategies.
///
/// NOTE(review): stored in the config but not referenced by the training
/// paths visible in this chunk — confirm where validation is performed.
#[derive(Debug, Clone, PartialEq)]
pub enum CrossTaskValidation {
    /// Leave one task out for validation
    LeaveOneTaskOut,
    /// Cross-validation within each task
    WithinTaskCV,
    /// Hierarchical cross-validation
    HierarchicalCV,
    /// Time-based validation for temporal tasks
    TemporalCV,
    /// Stratified validation across tasks
    StratifiedCV,
}
139
/// Multi-task ensemble classifier.
///
/// Uses the typestate pattern: the `State` parameter (`Untrained`/`Trained`)
/// is tracked via `PhantomData`, and the `Option` fields below are `None`
/// until fitting populates them.
pub struct MultiTaskEnsembleClassifier<State = Untrained> {
    config: MultiTaskEnsembleConfig,
    state: std::marker::PhantomData<State>,
    // Fitted attributes - only populated after training
    task_models: Option<HashMap<String, Vec<BaggingClassifier<Trained>>>>,
    shared_models: Option<Vec<BaggingClassifier<Trained>>>,
    task_similarities: Option<HashMap<(String, String), f64>>,
    task_weights: Option<HashMap<String, f64>>,
    feature_selector: Option<MultiTaskFeatureSelector>,
    task_hierarchy: Option<TaskHierarchy>,
}
152
/// Multi-task ensemble regressor.
///
/// Mirrors [`MultiTaskEnsembleClassifier`] but with gradient-boosting base
/// learners. Typestate (`Untrained`/`Trained`) is tracked via `PhantomData`;
/// the `Option` fields are populated by `fit_tasks`.
pub struct MultiTaskEnsembleRegressor<State = Untrained> {
    config: MultiTaskEnsembleConfig,
    state: std::marker::PhantomData<State>,
    // Fitted attributes - only populated after training
    task_models: Option<HashMap<String, Vec<TrainedGradientBoostingRegressor>>>,
    shared_models: Option<Vec<TrainedGradientBoostingRegressor>>,
    task_similarities: Option<HashMap<(String, String), f64>>,
    task_weights: Option<HashMap<String, f64>>,
    feature_selector: Option<MultiTaskFeatureSelector>,
    task_hierarchy: Option<TaskHierarchy>,
}
165
/// Task hierarchy for hierarchical sharing.
///
/// NOTE(review): the current `build_task_hierarchy` inserts an empty child
/// list per task and depth 0 for every task; only `hierarchy_weights` carries
/// information so far.
#[derive(Debug, Clone)]
pub struct TaskHierarchy {
    /// Parent-child relationships between tasks
    pub hierarchy: HashMap<String, Vec<String>>,
    /// Task depth in the hierarchy
    pub task_depths: HashMap<String, usize>,
    /// Sharing weights based on hierarchy
    pub hierarchy_weights: HashMap<(String, String), f64>,
}
176
/// Multi-task feature selector.
///
/// Built by `initialize_feature_selector`; importance scores are currently
/// per-feature variances, and masks keep features with non-negligible
/// variance.
#[derive(Debug, Clone)]
pub struct MultiTaskFeatureSelector {
    /// Task-specific feature masks
    pub task_feature_masks: HashMap<String, Vec<bool>>,
    /// Shared feature mask
    pub shared_feature_mask: Vec<bool>,
    /// Feature importance scores per task
    pub task_feature_importances: HashMap<String, Vec<f64>>,
    /// Global feature importance scores
    pub global_feature_importances: Vec<f64>,
}
189
/// Task-specific training data.
///
/// One instance per task; `features` rows correspond to `labels` entries.
#[derive(Debug, Clone)]
pub struct TaskData {
    /// Task identifier
    pub task_id: String,
    /// Features for this task
    pub features: Array2<f64>,
    /// Labels for this task
    pub labels: Vec<f64>,
    /// Sample weights (optional)
    pub sample_weights: Option<Vec<f64>>,
    /// Task metadata
    pub metadata: HashMap<String, String>,
}
204
/// Results from multi-task training.
///
/// NOTE(review): `fit_tasks` currently discards this value; it is only
/// meaningful when the `train_*` helpers are inspected directly.
#[derive(Debug, Clone)]
pub struct MultiTaskTrainingResults {
    /// Training metrics per task
    pub task_metrics: HashMap<String, TaskMetrics>,
    /// Cross-task transfer effects
    pub transfer_effects: HashMap<(String, String), f64>,
    /// Final task similarities
    pub final_similarities: HashMap<(String, String), f64>,
    /// Convergence information
    pub convergence_info: ConvergenceInfo,
}
217
/// Metrics for individual tasks.
///
/// NOTE(review): the visible trainers fill `validation_score` with the
/// training MSE and `training_time` with 0.0 — placeholders, not real
/// measurements.
#[derive(Debug, Clone)]
pub struct TaskMetrics {
    /// Training accuracy/error
    pub training_score: f64,
    /// Validation accuracy/error
    pub validation_score: f64,
    /// Number of training samples
    pub n_samples: usize,
    /// Training time
    pub training_time: f64,
    /// Model complexity measure
    pub complexity: f64,
}
232
/// Convergence information for multi-task training.
///
/// The visible trainers run a single pass and report `n_iterations: 1`,
/// `converged: true` with zeroed loss/tolerance.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Number of iterations to convergence
    pub n_iterations: usize,
    /// Final loss value
    pub final_loss: f64,
    /// Convergence tolerance achieved
    pub tolerance_achieved: f64,
    /// Whether convergence was reached
    pub converged: bool,
}
245
impl MultiTaskEnsembleConfig {
    /// Start a fluent builder pre-loaded with the `Default` configuration.
    pub fn builder() -> MultiTaskEnsembleConfigBuilder {
        MultiTaskEnsembleConfigBuilder::default()
    }
}
251
/// Fluent builder for [`MultiTaskEnsembleConfig`]; starts from the default
/// configuration and overrides fields one setter at a time.
#[derive(Default)]
pub struct MultiTaskEnsembleConfigBuilder {
    config: MultiTaskEnsembleConfig,
}
256
257impl MultiTaskEnsembleConfigBuilder {
258    pub fn n_estimators_per_task(mut self, n_estimators: usize) -> Self {
259        self.config.n_estimators_per_task = n_estimators;
260        self
261    }
262
263    pub fn sharing_strategy(mut self, strategy: TaskSharingStrategy) -> Self {
264        self.config.sharing_strategy = strategy;
265        self
266    }
267
268    pub fn similarity_metric(mut self, metric: TaskSimilarityMetric) -> Self {
269        self.config.similarity_metric = metric;
270        self
271    }
272
273    pub fn min_similarity_threshold(mut self, threshold: f64) -> Self {
274        self.config.min_similarity_threshold = threshold;
275        self
276    }
277
278    pub fn task_weighting(mut self, weighting: TaskWeightingStrategy) -> Self {
279        self.config.task_weighting = weighting;
280        self
281    }
282
283    pub fn use_task_specific_features(mut self, use_specific: bool) -> Self {
284        self.config.use_task_specific_features = use_specific;
285        self
286    }
287
288    pub fn sharing_regularization(mut self, regularization: f64) -> Self {
289        self.config.sharing_regularization = regularization;
290        self
291    }
292
293    pub fn cross_task_validation(mut self, validation: CrossTaskValidation) -> Self {
294        self.config.cross_task_validation = validation;
295        self
296    }
297
298    pub fn build(self) -> MultiTaskEnsembleConfig {
299        self.config
300    }
301}
302
303impl MultiTaskEnsembleRegressor {
304    pub fn new(config: MultiTaskEnsembleConfig) -> Self {
305        Self {
306            config,
307            state: std::marker::PhantomData,
308            task_models: None,
309            shared_models: None,
310            task_similarities: None,
311            task_weights: None,
312            feature_selector: None,
313            task_hierarchy: None,
314        }
315    }
316
    /// Start a fluent builder for the regressor.
    ///
    /// NOTE(review): `MultiTaskEnsembleRegressorBuilder` is defined outside
    /// this chunk — see its definition for available options.
    pub fn builder() -> MultiTaskEnsembleRegressorBuilder {
        MultiTaskEnsembleRegressorBuilder::new()
    }
320
321    /// Fit the multi-task ensemble on multiple tasks
322    pub fn fit_tasks(
323        mut self,
324        tasks: &[TaskData],
325    ) -> SklResult<MultiTaskEnsembleRegressor<Trained>> {
326        if tasks.is_empty() {
327            return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
328        }
329
330        // Initialize task weights
331        self.initialize_task_weights(tasks)?;
332
333        // Build task hierarchy if needed
334        if matches!(
335            self.config.sharing_strategy,
336            TaskSharingStrategy::HierarchicalSharing
337        ) {
338            self.build_task_hierarchy(tasks)?;
339        }
340
341        // Initialize feature selector
342        if self.config.use_task_specific_features {
343            self.initialize_feature_selector(tasks)?;
344        }
345
346        // Compute task similarities
347        self.compute_task_similarities(tasks)?;
348
349        // Train models based on sharing strategy
350        let training_results = match self.config.sharing_strategy {
351            TaskSharingStrategy::Independent => self.train_independent_tasks(tasks)?,
352            TaskSharingStrategy::SharedRepresentation => self.train_shared_representation(tasks)?,
353            TaskSharingStrategy::ParameterSharing => self.train_parameter_sharing(tasks)?,
354            TaskSharingStrategy::HierarchicalSharing => self.train_hierarchical_sharing(tasks)?,
355            TaskSharingStrategy::AdaptiveSharing => self.train_adaptive_sharing(tasks)?,
356            _ => self.train_independent_tasks(tasks)?, // Default fallback
357        };
358
359        let fitted_ensemble = MultiTaskEnsembleRegressor::<Trained> {
360            config: self.config,
361            state: std::marker::PhantomData,
362            task_models: self.task_models,
363            shared_models: self.shared_models,
364            task_similarities: self.task_similarities,
365            task_weights: self.task_weights,
366            feature_selector: self.feature_selector,
367            task_hierarchy: self.task_hierarchy,
368        };
369
370        Ok(fitted_ensemble)
371    }
372
373    /// Initialize task weights based on configuration
374    fn initialize_task_weights(&mut self, tasks: &[TaskData]) -> SklResult<()> {
375        let mut weights = HashMap::new();
376
377        match self.config.task_weighting {
378            TaskWeightingStrategy::Uniform => {
379                let weight = 1.0 / tasks.len() as f64;
380                for task in tasks {
381                    weights.insert(task.task_id.clone(), weight);
382                }
383            }
384            TaskWeightingStrategy::SampleSizeBased => {
385                let total_samples: usize = tasks.iter().map(|t| t.features.shape()[0]).sum();
386                for task in tasks {
387                    let weight = task.features.shape()[0] as f64 / total_samples as f64;
388                    weights.insert(task.task_id.clone(), weight);
389                }
390            }
391            _ => {
392                // Default to uniform for now
393                let weight = 1.0 / tasks.len() as f64;
394                for task in tasks {
395                    weights.insert(task.task_id.clone(), weight);
396                }
397            }
398        }
399
400        self.task_weights = Some(weights);
401        Ok(())
402    }
403
404    /// Build task hierarchy for hierarchical sharing
405    fn build_task_hierarchy(&mut self, tasks: &[TaskData]) -> SklResult<()> {
406        let mut hierarchy = HashMap::new();
407        let mut task_depths = HashMap::new();
408        let mut hierarchy_weights = HashMap::new();
409
410        // Simple hierarchical construction based on task similarities
411        // In practice, this could be based on domain knowledge or learned
412        for (i, task1) in tasks.iter().enumerate() {
413            task_depths.insert(task1.task_id.clone(), 0);
414            hierarchy.insert(task1.task_id.clone(), Vec::new());
415
416            for (j, task2) in tasks.iter().enumerate() {
417                if i != j {
418                    // Set hierarchical weight based on some similarity measure
419                    let weight = 1.0 / (i.abs_diff(j) + 1) as f64;
420                    hierarchy_weights
421                        .insert((task1.task_id.clone(), task2.task_id.clone()), weight);
422                }
423            }
424        }
425
426        self.task_hierarchy = Some(TaskHierarchy {
427            hierarchy,
428            task_depths,
429            hierarchy_weights,
430        });
431
432        Ok(())
433    }
434
435    /// Initialize feature selector for task-specific features
436    fn initialize_feature_selector(&mut self, tasks: &[TaskData]) -> SklResult<()> {
437        if tasks.is_empty() {
438            return Ok(());
439        }
440
441        let n_features = tasks[0].features.shape()[1];
442        let mut task_feature_masks = HashMap::new();
443        let mut task_feature_importances = HashMap::new();
444
445        // For now, use simple feature selection based on variance
446        for task in tasks {
447            let mut feature_mask = vec![true; n_features];
448            let mut feature_importances = vec![0.0; n_features];
449
450            // Calculate feature variances as a simple importance measure
451            for j in 0..n_features {
452                let column: Vec<f64> = (0..task.features.shape()[0])
453                    .map(|i| task.features[[i, j]])
454                    .collect();
455                let mean = column.iter().sum::<f64>() / column.len() as f64;
456                let variance =
457                    column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / column.len() as f64;
458
459                feature_importances[j] = variance;
460                feature_mask[j] = variance > 1e-8; // Keep features with non-zero variance
461            }
462
463            task_feature_masks.insert(task.task_id.clone(), feature_mask);
464            task_feature_importances.insert(task.task_id.clone(), feature_importances);
465        }
466
467        // Global feature importance (average across tasks)
468        let mut global_feature_importances = vec![0.0; n_features];
469        for j in 0..n_features {
470            let sum: f64 = task_feature_importances
471                .values()
472                .map(|importances| importances[j])
473                .sum();
474            global_feature_importances[j] = sum / tasks.len() as f64;
475        }
476
477        // Shared feature mask (features important across multiple tasks)
478        let shared_feature_mask: Vec<bool> = (0..n_features)
479            .map(|j| {
480                let important_count = task_feature_importances
481                    .values()
482                    .filter(|importances| importances[j] > global_feature_importances[j] * 0.5)
483                    .count();
484                important_count >= tasks.len() / 2
485            })
486            .collect();
487
488        self.feature_selector = Some(MultiTaskFeatureSelector {
489            task_feature_masks,
490            shared_feature_mask,
491            task_feature_importances,
492            global_feature_importances,
493        });
494
495        Ok(())
496    }
497
498    /// Compute similarities between tasks
499    fn compute_task_similarities(&mut self, tasks: &[TaskData]) -> SklResult<()> {
500        let mut similarities = HashMap::new();
501
502        for (i, task1) in tasks.iter().enumerate() {
503            for task2 in tasks.iter().skip(i + 1) {
504                let similarity = self.calculate_task_similarity(task1, task2)?;
505                similarities.insert((task1.task_id.clone(), task2.task_id.clone()), similarity);
506                similarities.insert((task2.task_id.clone(), task1.task_id.clone()), similarity);
507            }
508        }
509
510        self.task_similarities = Some(similarities);
511        Ok(())
512    }
513
514    /// Calculate similarity between two tasks
515    fn calculate_task_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
516        match self.config.similarity_metric {
517            TaskSimilarityMetric::CorrelationBased => self.correlation_similarity(task1, task2),
518            TaskSimilarityMetric::DistributionSimilarity => {
519                self.distribution_similarity(task1, task2)
520            }
521            _ => {
522                // Default to correlation-based similarity
523                self.correlation_similarity(task1, task2)
524            }
525        }
526    }
527
528    /// Calculate correlation-based similarity between tasks
529    fn correlation_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
530        // Simple correlation between target variables if they have the same length
531        if task1.labels.len() != task2.labels.len() {
532            return Ok(0.0); // No similarity if different lengths
533        }
534
535        let n = task1.labels.len();
536        if n < 2 {
537            return Ok(0.0);
538        }
539
540        let mean1 = task1.labels.iter().sum::<f64>() / n as f64;
541        let mean2 = task2.labels.iter().sum::<f64>() / n as f64;
542
543        let mut numerator = 0.0;
544        let mut denom1 = 0.0;
545        let mut denom2 = 0.0;
546
547        for i in 0..n {
548            let diff1 = task1.labels[i] - mean1;
549            let diff2 = task2.labels[i] - mean2;
550            numerator += diff1 * diff2;
551            denom1 += diff1 * diff1;
552            denom2 += diff2 * diff2;
553        }
554
555        if denom1 * denom2 > 0.0 {
556            Ok(numerator / (denom1 * denom2).sqrt())
557        } else {
558            Ok(0.0)
559        }
560    }
561
562    /// Calculate distribution similarity between tasks
563    fn distribution_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
564        // Simple approach: compare feature means and variances
565        let n_features1 = task1.features.shape()[1];
566        let n_features2 = task2.features.shape()[1];
567
568        if n_features1 != n_features2 {
569            return Ok(0.0);
570        }
571
572        let mut similarity_sum = 0.0;
573
574        for j in 0..n_features1 {
575            let col1: Vec<f64> = (0..task1.features.shape()[0])
576                .map(|i| task1.features[[i, j]])
577                .collect();
578            let col2: Vec<f64> = (0..task2.features.shape()[0])
579                .map(|i| task2.features[[i, j]])
580                .collect();
581
582            let mean1 = col1.iter().sum::<f64>() / col1.len() as f64;
583            let mean2 = col2.iter().sum::<f64>() / col2.len() as f64;
584
585            let var1 = col1.iter().map(|&x| (x - mean1).powi(2)).sum::<f64>() / col1.len() as f64;
586            let var2 = col2.iter().map(|&x| (x - mean2).powi(2)).sum::<f64>() / col2.len() as f64;
587
588            // Similarity based on mean and variance differences
589            let mean_sim = 1.0 - (mean1 - mean2).abs() / (mean1.abs() + mean2.abs() + 1e-8);
590            let var_sim = 1.0 - (var1 - var2).abs() / (var1 + var2 + 1e-8);
591
592            similarity_sum += (mean_sim + var_sim) / 2.0;
593        }
594
595        Ok(similarity_sum / n_features1 as f64)
596    }
597
598    /// Train independent models for each task
599    fn train_independent_tasks(
600        &mut self,
601        tasks: &[TaskData],
602    ) -> SklResult<MultiTaskTrainingResults> {
603        let mut task_metrics = HashMap::new();
604
605        // Initialize task_models
606        if self.task_models.is_none() {
607            self.task_models = Some(HashMap::new());
608        }
609
610        for task in tasks {
611            let mut models = Vec::new();
612
613            for _ in 0..self.config.n_estimators_per_task {
614                let gb_config = GradientBoostingConfig {
615                    n_estimators: 50,
616                    learning_rate: 0.1,
617                    max_depth: 6,
618                    ..Default::default()
619                };
620
621                let y_array = Array1::from_vec(task.labels.clone());
622                let model = GradientBoostingRegressor::builder()
623                    .n_estimators(50)
624                    .learning_rate(0.1)
625                    .max_depth(6)
626                    .build()
627                    .fit(&task.features, &y_array)?;
628
629                models.push(model);
630            }
631
632            // Calculate task metrics
633            let predictions = self.predict_task_ensemble(&models, &task.features)?;
634            let mse = self.calculate_mse(&predictions, &task.labels);
635
636            task_metrics.insert(
637                task.task_id.clone(),
638                TaskMetrics {
639                    training_score: mse,
640                    validation_score: mse, // Would be different in practice
641                    n_samples: task.features.shape()[0],
642                    training_time: 0.0, // Would measure actual time
643                    complexity: models.len() as f64,
644                },
645            );
646
647            self.task_models
648                .as_mut()
649                .expect("operation should succeed")
650                .insert(task.task_id.clone(), models);
651        }
652
653        Ok(MultiTaskTrainingResults {
654            task_metrics,
655            transfer_effects: HashMap::new(),
656            final_similarities: self.task_similarities.clone().unwrap_or_default(),
657            convergence_info: ConvergenceInfo {
658                n_iterations: 1,
659                final_loss: 0.0,
660                tolerance_achieved: 0.0,
661                converged: true,
662            },
663        })
664    }
665
666    /// Train with shared representation learning
667    fn train_shared_representation(
668        &mut self,
669        tasks: &[TaskData],
670    ) -> SklResult<MultiTaskTrainingResults> {
671        // First train shared models on combined data
672        let combined_data = self.combine_task_data(tasks)?;
673
674        // Initialize shared_models
675        if self.shared_models.is_none() {
676            self.shared_models = Some(Vec::new());
677        }
678
679        for _ in 0..self.config.n_estimators_per_task {
680            let gb_config = GradientBoostingConfig {
681                n_estimators: 30,
682                learning_rate: 0.1,
683                max_depth: 4,
684                ..Default::default()
685            };
686
687            let y_array = Array1::from_vec(combined_data.labels.clone());
688            let shared_model = GradientBoostingRegressor::builder()
689                .n_estimators(30)
690                .learning_rate(0.1)
691                .max_depth(4)
692                .build()
693                .fit(&combined_data.features, &y_array)?;
694
695            self.shared_models
696                .as_mut()
697                .expect("operation should succeed")
698                .push(shared_model);
699        }
700
701        // Then train task-specific models
702        self.train_independent_tasks(tasks)
703    }
704
705    /// Train with parameter sharing between similar tasks
706    fn train_parameter_sharing(
707        &mut self,
708        tasks: &[TaskData],
709    ) -> SklResult<MultiTaskTrainingResults> {
710        // Group similar tasks
711        let task_groups = self.group_similar_tasks(tasks)?;
712
713        let mut task_metrics = HashMap::new();
714
715        for group in task_groups {
716            // Train shared models for this group
717            let group_data = self.combine_group_data(tasks, &group)?;
718
719            let mut group_models = Vec::new();
720            for _ in 0..self.config.n_estimators_per_task {
721                let gb_config = GradientBoostingConfig {
722                    n_estimators: 40,
723                    learning_rate: 0.1,
724                    max_depth: 5,
725                    ..Default::default()
726                };
727
728                let y_array = Array1::from_vec(group_data.labels.clone());
729                let model = GradientBoostingRegressor::builder()
730                    .n_estimators(40)
731                    .learning_rate(0.1)
732                    .max_depth(5)
733                    .build()
734                    .fit(&group_data.features, &y_array)?;
735
736                group_models.push(model);
737            }
738
739            // Calculate metrics for tasks in this group and create separate model sets
740            for task_id in &group {
741                let task = tasks
742                    .iter()
743                    .find(|t| &t.task_id == task_id)
744                    .expect("operation should succeed");
745
746                // Calculate metrics for this task
747                let predictions = self.predict_task_ensemble(&group_models, &task.features)?;
748                let mse = self.calculate_mse(&predictions, &task.labels);
749
750                task_metrics.insert(
751                    task_id.clone(),
752                    TaskMetrics {
753                        training_score: mse,
754                        validation_score: mse,
755                        n_samples: task.features.shape()[0],
756                        training_time: 0.0,
757                        complexity: group_models.len() as f64,
758                    },
759                );
760
761                // For each task, train a separate set of models with the same configuration
762                let mut task_models = Vec::new();
763                for _ in 0..self.config.n_estimators_per_task {
764                    let y_array = Array1::from_vec(group_data.labels.clone());
765                    let model = GradientBoostingRegressor::builder()
766                        .n_estimators(40)
767                        .learning_rate(0.1)
768                        .max_depth(5)
769                        .build()
770                        .fit(&group_data.features, &y_array)?;
771                    task_models.push(model);
772                }
773
774                self.task_models
775                    .as_mut()
776                    .expect("operation should succeed")
777                    .insert(task_id.clone(), task_models);
778            }
779        }
780
781        Ok(MultiTaskTrainingResults {
782            task_metrics,
783            transfer_effects: HashMap::new(),
784            final_similarities: self.task_similarities.clone().unwrap_or_default(),
785            convergence_info: ConvergenceInfo {
786                n_iterations: 1,
787                final_loss: 0.0,
788                tolerance_achieved: 0.0,
789                converged: true,
790            },
791        })
792    }
793
    /// Train with hierarchical sharing.
    ///
    /// Currently a thin delegate to [`Self::train_parameter_sharing`]; a real
    /// implementation would exploit the parent/child relationships built by
    /// `build_task_hierarchy`.
    fn train_hierarchical_sharing(
        &mut self,
        tasks: &[TaskData],
    ) -> SklResult<MultiTaskTrainingResults> {
        // For now, delegate to parameter sharing
        // In practice, this would implement hierarchical relationships
        self.train_parameter_sharing(tasks)
    }
803
804    /// Train with adaptive sharing based on task similarities
805    fn train_adaptive_sharing(
806        &mut self,
807        tasks: &[TaskData],
808    ) -> SklResult<MultiTaskTrainingResults> {
809        let mut task_metrics = HashMap::new();
810
811        for task in tasks {
812            let mut models = Vec::new();
813
814            // Find similar tasks for this task
815            let similar_tasks = self.find_similar_tasks(&task.task_id, tasks);
816
817            if similar_tasks.len() > 1 {
818                // Train with data from similar tasks
819                let combined_data = self.combine_similar_task_data(tasks, &similar_tasks)?;
820
821                for _ in 0..self.config.n_estimators_per_task {
822                    let gb_config = GradientBoostingConfig {
823                        n_estimators: 50,
824                        learning_rate: 0.1,
825                        max_depth: 6,
826                        ..Default::default()
827                    };
828
829                    let y_array = Array1::from_vec(combined_data.labels.clone());
830                    let model = GradientBoostingRegressor::builder()
831                        .n_estimators(50)
832                        .learning_rate(0.1)
833                        .max_depth(6)
834                        .build()
835                        .fit(&combined_data.features, &y_array)?;
836
837                    models.push(model);
838                }
839            } else {
840                // Train independently if no similar tasks
841                for _ in 0..self.config.n_estimators_per_task {
842                    let gb_config = GradientBoostingConfig {
843                        n_estimators: 50,
844                        learning_rate: 0.1,
845                        max_depth: 6,
846                        ..Default::default()
847                    };
848
849                    let y_array = Array1::from_vec(task.labels.clone());
850                    let model = GradientBoostingRegressor::builder()
851                        .n_estimators(50)
852                        .learning_rate(0.1)
853                        .max_depth(6)
854                        .build()
855                        .fit(&task.features, &y_array)?;
856
857                    models.push(model);
858                }
859            }
860
861            // Calculate metrics
862            let predictions = self.predict_task_ensemble(&models, &task.features)?;
863            let mse = self.calculate_mse(&predictions, &task.labels);
864
865            task_metrics.insert(
866                task.task_id.clone(),
867                TaskMetrics {
868                    training_score: mse,
869                    validation_score: mse,
870                    n_samples: task.features.shape()[0],
871                    training_time: 0.0,
872                    complexity: models.len() as f64,
873                },
874            );
875
876            self.task_models
877                .as_mut()
878                .expect("operation should succeed")
879                .insert(task.task_id.clone(), models);
880        }
881
882        Ok(MultiTaskTrainingResults {
883            task_metrics,
884            transfer_effects: HashMap::new(),
885            final_similarities: self.task_similarities.clone().unwrap_or_default(),
886            convergence_info: ConvergenceInfo {
887                n_iterations: 1,
888                final_loss: 0.0,
889                tolerance_achieved: 0.0,
890                converged: true,
891            },
892        })
893    }
894
895    /// Combine data from multiple tasks
896    fn combine_task_data(&self, tasks: &[TaskData]) -> SklResult<TaskData> {
897        if tasks.is_empty() {
898            return Err(SklearsError::InvalidInput(
899                "No tasks to combine".to_string(),
900            ));
901        }
902
903        let total_samples: usize = tasks.iter().map(|t| t.features.shape()[0]).sum();
904        let n_features = tasks[0].features.shape()[1];
905
906        let mut combined_features = Vec::with_capacity(total_samples * n_features);
907        let mut combined_labels = Vec::with_capacity(total_samples);
908
909        for task in tasks {
910            for i in 0..task.features.shape()[0] {
911                for j in 0..n_features {
912                    combined_features.push(task.features[[i, j]]);
913                }
914                combined_labels.push(task.labels[i]);
915            }
916        }
917
918        let features = Array2::from_shape_vec((total_samples, n_features), combined_features)?;
919
920        Ok(TaskData {
921            task_id: "combined".to_string(),
922            features,
923            labels: combined_labels,
924            sample_weights: None,
925            metadata: HashMap::new(),
926        })
927    }
928
929    /// Group similar tasks together
930    fn group_similar_tasks(&self, tasks: &[TaskData]) -> SklResult<Vec<Vec<String>>> {
931        let mut groups = Vec::new();
932        let mut assigned = vec![false; tasks.len()];
933
934        for (i, task) in tasks.iter().enumerate() {
935            if assigned[i] {
936                continue;
937            }
938
939            let mut group = vec![task.task_id.clone()];
940            assigned[i] = true;
941
942            // Find similar tasks
943            for (j, other_task) in tasks.iter().enumerate() {
944                if i != j && !assigned[j] {
945                    let similarity = self
946                        .task_similarities
947                        .as_ref()
948                        .and_then(|similarities| {
949                            similarities.get(&(task.task_id.clone(), other_task.task_id.clone()))
950                        })
951                        .copied()
952                        .unwrap_or(0.0);
953
954                    if similarity >= self.config.min_similarity_threshold {
955                        group.push(other_task.task_id.clone());
956                        assigned[j] = true;
957                    }
958                }
959            }
960
961            groups.push(group);
962        }
963
964        Ok(groups)
965    }
966
967    /// Combine data from a group of tasks
968    fn combine_group_data(&self, tasks: &[TaskData], group: &[String]) -> SklResult<TaskData> {
969        let group_tasks: Vec<TaskData> = tasks
970            .iter()
971            .filter(|t| group.contains(&t.task_id))
972            .cloned()
973            .collect();
974
975        self.combine_task_data(&group_tasks)
976    }
977
978    /// Find tasks similar to a given task
979    fn find_similar_tasks(&self, task_id: &str, tasks: &[TaskData]) -> Vec<String> {
980        let mut similar_tasks = vec![task_id.to_string()];
981
982        for task in tasks {
983            if task.task_id != task_id {
984                let similarity = self
985                    .task_similarities
986                    .as_ref()
987                    .and_then(|similarities| {
988                        similarities.get(&(task_id.to_string(), task.task_id.clone()))
989                    })
990                    .copied()
991                    .unwrap_or(0.0);
992
993                if similarity >= self.config.min_similarity_threshold {
994                    similar_tasks.push(task.task_id.clone());
995                }
996            }
997        }
998
999        similar_tasks
1000    }
1001
1002    /// Combine data from similar tasks
1003    fn combine_similar_task_data(
1004        &self,
1005        tasks: &[TaskData],
1006        similar_task_ids: &[String],
1007    ) -> SklResult<TaskData> {
1008        let similar_tasks: Vec<TaskData> = tasks
1009            .iter()
1010            .filter(|t| similar_task_ids.contains(&t.task_id))
1011            .cloned()
1012            .collect();
1013
1014        self.combine_task_data(&similar_tasks)
1015    }
1016
1017    /// Predict using task ensemble
1018    fn predict_task_ensemble(
1019        &self,
1020        models: &[TrainedGradientBoostingRegressor],
1021        X: &Array2<f64>,
1022    ) -> SklResult<Vec<f64>> {
1023        if models.is_empty() {
1024            return Err(SklearsError::InvalidInput(
1025                "No models in ensemble".to_string(),
1026            ));
1027        }
1028
1029        let mut predictions = vec![0.0; X.shape()[0]];
1030
1031        for model in models {
1032            let pred = model.predict(X)?;
1033            for (i, &p) in pred.iter().enumerate() {
1034                predictions[i] += p;
1035            }
1036        }
1037
1038        // Average predictions
1039        for p in &mut predictions {
1040            *p /= models.len() as f64;
1041        }
1042
1043        Ok(predictions)
1044    }
1045
1046    /// Calculate mean squared error
1047    fn calculate_mse(&self, predictions: &[f64], targets: &[f64]) -> f64 {
1048        if predictions.len() != targets.len() {
1049            return f64::INFINITY;
1050        }
1051
1052        let sum_squared_error: f64 = predictions
1053            .iter()
1054            .zip(targets.iter())
1055            .map(|(&p, &t)| (p - t).powi(2))
1056            .sum();
1057
1058        sum_squared_error / predictions.len() as f64
1059    }
1060}
1061
1062impl MultiTaskEnsembleRegressor<Trained> {
1063    /// Predict for a specific task
1064    pub fn predict_task(&self, task_id: &str, X: &Array2<f64>) -> SklResult<Vec<f64>> {
1065        // Use task-specific models if available
1066        if let Some(models) = self.task_models.as_ref().and_then(|m| m.get(task_id)) {
1067            let mut predictions = self.predict_task_ensemble(models, X)?;
1068
1069            // Add shared model predictions if available
1070            if let Some(shared_models) = self.shared_models.as_ref() {
1071                if !shared_models.is_empty() {
1072                    let shared_predictions = self.predict_task_ensemble(shared_models, X)?;
1073                    for (i, &shared_pred) in shared_predictions.iter().enumerate() {
1074                        predictions[i] = 0.7 * predictions[i] + 0.3 * shared_pred;
1075                    }
1076                }
1077            }
1078
1079            Ok(predictions)
1080        } else {
1081            Err(SklearsError::InvalidInput(format!(
1082                "Task '{}' not found",
1083                task_id
1084            )))
1085        }
1086    }
1087
1088    /// Predict using task ensemble (internal method)
1089    fn predict_task_ensemble(
1090        &self,
1091        models: &[TrainedGradientBoostingRegressor],
1092        X: &Array2<f64>,
1093    ) -> SklResult<Vec<f64>> {
1094        if models.is_empty() {
1095            return Err(SklearsError::InvalidInput(
1096                "No models in ensemble".to_string(),
1097            ));
1098        }
1099
1100        let mut predictions = vec![0.0; X.shape()[0]];
1101
1102        for model in models {
1103            let pred = model.predict(X)?;
1104            for (i, &p) in pred.iter().enumerate() {
1105                predictions[i] += p;
1106            }
1107        }
1108
1109        // Average predictions
1110        for p in &mut predictions {
1111            *p /= models.len() as f64;
1112        }
1113
1114        Ok(predictions)
1115    }
1116
1117    /// Get task similarities
1118    pub fn get_task_similarities(&self) -> &HashMap<(String, String), f64> {
1119        self.task_similarities.as_ref().expect("Model is trained")
1120    }
1121
1122    /// Get task weights
1123    pub fn get_task_weights(&self) -> &HashMap<String, f64> {
1124        self.task_weights.as_ref().expect("Model is trained")
1125    }
1126}
1127
/// Builder for `MultiTaskEnsembleRegressor`, providing a fluent API over
/// `MultiTaskEnsembleConfig`.
pub struct MultiTaskEnsembleRegressorBuilder {
    // Configuration accumulated by the setter methods; consumed by `build`.
    config: MultiTaskEnsembleConfig,
}
1131
1132impl Default for MultiTaskEnsembleRegressorBuilder {
1133    fn default() -> Self {
1134        Self::new()
1135    }
1136}
1137
1138impl MultiTaskEnsembleRegressorBuilder {
1139    pub fn new() -> Self {
1140        Self {
1141            config: MultiTaskEnsembleConfig::default(),
1142        }
1143    }
1144
1145    pub fn config(mut self, config: MultiTaskEnsembleConfig) -> Self {
1146        self.config = config;
1147        self
1148    }
1149
1150    pub fn n_estimators_per_task(mut self, n_estimators: usize) -> Self {
1151        self.config.n_estimators_per_task = n_estimators;
1152        self
1153    }
1154
1155    pub fn sharing_strategy(mut self, strategy: TaskSharingStrategy) -> Self {
1156        self.config.sharing_strategy = strategy;
1157        self
1158    }
1159
1160    pub fn min_similarity_threshold(mut self, threshold: f64) -> Self {
1161        self.config.min_similarity_threshold = threshold;
1162        self
1163    }
1164
1165    pub fn build(self) -> MultiTaskEnsembleRegressor {
1166        MultiTaskEnsembleRegressor::new(self.config)
1167    }
1168}
1169
// Registers `MultiTaskEnsembleRegressor` with the generic sklears `Estimator`
// machinery, fixing its configuration, error, and float types.
impl Estimator for MultiTaskEnsembleRegressor {
    type Config = MultiTaskEnsembleConfig;
    type Error = SklearsError;
    type Float = f64;

    // Expose the ensemble's configuration to generic callers.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
1179
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // The config builder should faithfully record every setter call.
    #[test]
    fn test_multi_task_config() {
        let cfg = MultiTaskEnsembleConfig::builder()
            .n_estimators_per_task(5)
            .sharing_strategy(TaskSharingStrategy::SharedRepresentation)
            .min_similarity_threshold(0.5)
            .build();

        assert_eq!(cfg.n_estimators_per_task, 5);
        assert_eq!(cfg.min_similarity_threshold, 0.5);
        assert_eq!(
            cfg.sharing_strategy,
            TaskSharingStrategy::SharedRepresentation
        );
    }

    // A `TaskData` value should expose exactly the fields it was built with.
    #[test]
    fn test_task_data_creation() {
        let xs = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape and data length should match");

        let task = TaskData {
            task_id: "test_task".to_string(),
            features: xs,
            labels: vec![1.0, 2.0, 3.0],
            sample_weights: None,
            metadata: HashMap::new(),
        };

        assert_eq!(task.task_id, "test_task");
        assert_eq!(task.labels.len(), 3);
        assert_eq!(task.features.shape(), &[3, 2]);
    }

    // A freshly built ensemble carries its config and no trained models.
    #[test]
    fn test_multi_task_ensemble_basic() {
        let cfg = MultiTaskEnsembleConfig::builder()
            .n_estimators_per_task(2)
            .sharing_strategy(TaskSharingStrategy::Independent)
            .build();

        let ensemble = MultiTaskEnsembleRegressor::new(cfg);

        assert_eq!(ensemble.config.n_estimators_per_task, 2);
        assert_eq!(
            ensemble.config.sharing_strategy,
            TaskSharingStrategy::Independent
        );
        // In untrained state, models should be None
        assert!(ensemble.task_models.is_none());
    }

    // Identical label vectors must yield a correlation similarity of 1.
    #[test]
    fn test_task_similarity_calculation() {
        let ensemble = MultiTaskEnsembleRegressor::new(MultiTaskEnsembleConfig::default());

        let make_task = |id: &str| TaskData {
            task_id: id.to_string(),
            features: Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
                .expect("shape and data length should match"),
            labels: vec![1.0, 2.0],
            sample_weights: None,
            metadata: HashMap::new(),
        };

        let task1 = make_task("task1");
        let task2 = make_task("task2"); // same labels as task1

        let similarity = ensemble
            .correlation_similarity(&task1, &task2)
            .expect("operation should succeed");
        assert!((similarity - 1.0).abs() < 1e-10); // Should be perfectly correlated
    }
}