1use crate::bagging::BaggingClassifier;
8use crate::gradient_boosting::{
9 GradientBoostingConfig, GradientBoostingRegressor, TrainedGradientBoostingRegressor,
10};
11use scirs2_core::ndarray::{Array1, Array2};
12use sklears_core::{
13 error::Result as SklResult,
14 prelude::{Predict, SklearsError},
15 traits::{Estimator, Fit, Trained, Untrained},
16};
17use std::collections::HashMap;
18
/// Marker trait for estimators usable as per-task models: any [`Estimator`]
/// with the given associated types that can also [`Predict`].
pub trait MultiTaskEstimator<C, E, F, X, Y>:
    Estimator<Config = C, Error = E, Float = F> + Predict<X, Y>
{
}
24
/// Blanket implementation: every qualifying estimator is automatically a
/// [`MultiTaskEstimator`]; no manual impls are needed.
impl<T, C, E, F, X, Y> MultiTaskEstimator<C, E, F, X, Y> for T where
    T: Estimator<Config = C, Error = E, Float = F> + Predict<X, Y>
{
}
29
/// Configuration for the multi-task ensemble estimators.
#[derive(Debug, Clone)]
pub struct MultiTaskEnsembleConfig {
    /// Number of base estimators trained per task.
    pub n_estimators_per_task: usize,
    /// How knowledge is shared between tasks during training.
    pub sharing_strategy: TaskSharingStrategy,
    /// Metric used to compute pairwise task similarity.
    pub similarity_metric: TaskSimilarityMetric,
    /// Minimum similarity for two tasks to be considered related.
    pub min_similarity_threshold: f64,
    /// How per-task weights are assigned.
    pub task_weighting: TaskWeightingStrategy,
    /// Whether to compute per-task feature masks/importances.
    pub use_task_specific_features: bool,
    /// Optional cap on the number of shared features.
    /// NOTE(review): not read by the training code visible here — confirm use.
    pub n_shared_features: Option<usize>,
    /// Regularization strength applied to sharing.
    /// NOTE(review): not read by the training code visible here — confirm use.
    pub sharing_regularization: f64,
    /// Maximum depth of the task hierarchy.
    pub max_task_depth: usize,
    /// Cross-task validation scheme.
    /// NOTE(review): not read by the training code visible here — confirm use.
    pub cross_task_validation: CrossTaskValidation,
}
54
impl Default for MultiTaskEnsembleConfig {
    /// Defaults: shared representation, correlation-based similarity,
    /// uniform task weighting, task-specific features enabled.
    fn default() -> Self {
        Self {
            n_estimators_per_task: 10,
            sharing_strategy: TaskSharingStrategy::SharedRepresentation,
            similarity_metric: TaskSimilarityMetric::CorrelationBased,
            min_similarity_threshold: 0.3,
            task_weighting: TaskWeightingStrategy::Uniform,
            use_task_specific_features: true,
            n_shared_features: None,
            sharing_regularization: 0.1,
            max_task_depth: 3,
            cross_task_validation: CrossTaskValidation::LeaveOneTaskOut,
        }
    }
}
71
/// Strategy for sharing information between tasks during training.
///
/// Only the first five variants have dedicated training paths in this file;
/// `MultiLevelSharing` and `TransferLearning` currently fall back to
/// independent training (see `fit_tasks`).
#[derive(Debug, Clone, PartialEq)]
pub enum TaskSharingStrategy {
    /// Train each task in isolation.
    Independent,
    /// Train shared models on pooled data plus per-task models.
    SharedRepresentation,
    /// Group similar tasks and train on pooled group data.
    ParameterSharing,
    /// Hierarchical sharing (currently delegates to parameter sharing).
    HierarchicalSharing,
    /// Pool data from similar tasks per task, else train independently.
    AdaptiveSharing,
    /// Multi-level sharing (no dedicated implementation yet).
    MultiLevelSharing,
    /// Transfer learning (no dedicated implementation yet).
    TransferLearning,
}
90
/// Metric used to score how similar two tasks are.
///
/// Only `CorrelationBased` and `DistributionSimilarity` have dedicated
/// implementations in this file; the other variants fall back to the
/// correlation-based similarity (see `calculate_task_similarity`).
#[derive(Debug, Clone, PartialEq)]
pub enum TaskSimilarityMetric {
    /// Pearson correlation between the tasks' label vectors.
    CorrelationBased,
    /// Similarity of feature-importance profiles (fallback: correlation).
    FeatureImportanceSimilarity,
    /// Similarity of model predictions (fallback: correlation).
    PredictionSimilarity,
    /// Moment-matching similarity over the tasks' input features.
    DistributionSimilarity,
    /// Similarity of training gradients (fallback: correlation).
    GradientSimilarity,
    /// Correlation of task performance (fallback: correlation).
    PerformanceCorrelation,
}
107
/// Strategy for assigning a weight to each task.
///
/// Only `Uniform` and `SampleSizeBased` are implemented in
/// `initialize_task_weights`; the remaining variants currently fall back
/// to uniform weighting.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskWeightingStrategy {
    /// Every task gets the same weight (1 / number of tasks).
    Uniform,
    /// Weight by task difficulty (not yet implemented; falls back to uniform).
    DifficultyBased,
    /// Weight by the task's share of the total sample count.
    SampleSizeBased,
    /// Weight by task performance (not yet implemented; falls back to uniform).
    PerformanceBased,
    /// Adapt weights during training (not yet implemented; falls back to uniform).
    AdaptiveWeighting,
    /// Weight by task importance (not yet implemented; falls back to uniform).
    ImportanceBased,
}
124
/// Cross-task validation scheme.
///
/// NOTE(review): this enum is stored in the configuration but is not read
/// by any training code visible in this file — confirm intended use.
#[derive(Debug, Clone, PartialEq)]
pub enum CrossTaskValidation {
    /// Validate on each task after training on all others.
    LeaveOneTaskOut,
    /// Cross-validate within each task independently.
    WithinTaskCV,
    /// Cross-validate respecting the task hierarchy.
    HierarchicalCV,
    /// Time-ordered validation splits.
    TemporalCV,
    /// Stratified validation splits.
    StratifiedCV,
}
139
/// Multi-task ensemble classifier (typestate: `Untrained` → `Trained`).
///
/// NOTE(review): no `impl` blocks for this type are visible in this file
/// chunk; the fields mirror [`MultiTaskEnsembleRegressor`].
pub struct MultiTaskEnsembleClassifier<State = Untrained> {
    config: MultiTaskEnsembleConfig,
    /// Typestate marker; holds no runtime data.
    state: std::marker::PhantomData<State>,
    /// Per-task bagging ensembles keyed by task id (populated when trained).
    task_models: Option<HashMap<String, Vec<BaggingClassifier<Trained>>>>,
    /// Models trained on pooled data from all tasks.
    shared_models: Option<Vec<BaggingClassifier<Trained>>>,
    /// Cached pairwise task similarities.
    task_similarities: Option<HashMap<(String, String), f64>>,
    /// Per-task weights.
    task_weights: Option<HashMap<String, f64>>,
    /// Feature masks/importances shared across tasks.
    feature_selector: Option<MultiTaskFeatureSelector>,
    /// Optional task hierarchy.
    task_hierarchy: Option<TaskHierarchy>,
}
152
/// Multi-task ensemble regressor built on gradient boosting
/// (typestate: `Untrained` → `Trained` via `fit_tasks`).
pub struct MultiTaskEnsembleRegressor<State = Untrained> {
    config: MultiTaskEnsembleConfig,
    /// Typestate marker; holds no runtime data.
    state: std::marker::PhantomData<State>,
    /// Per-task ensembles keyed by task id; populated during training.
    task_models: Option<HashMap<String, Vec<TrainedGradientBoostingRegressor>>>,
    /// Models trained on pooled data (SharedRepresentation strategy).
    shared_models: Option<Vec<TrainedGradientBoostingRegressor>>,
    /// Cached pairwise similarities, stored for both orderings of each pair.
    task_similarities: Option<HashMap<(String, String), f64>>,
    /// Per-task weights from `initialize_task_weights`.
    task_weights: Option<HashMap<String, f64>>,
    /// Feature masks/importances (when `use_task_specific_features` is set).
    feature_selector: Option<MultiTaskFeatureSelector>,
    /// Placeholder hierarchy (HierarchicalSharing strategy).
    task_hierarchy: Option<TaskHierarchy>,
}
165
/// Parent→children relationships between tasks plus per-pair weights.
#[derive(Debug, Clone)]
pub struct TaskHierarchy {
    /// Children of each task, keyed by task id.
    /// NOTE(review): `build_task_hierarchy` always leaves these lists empty —
    /// confirm whether real hierarchy construction is planned.
    pub hierarchy: HashMap<String, Vec<String>>,
    /// Depth of each task in the hierarchy (currently always 0).
    pub task_depths: HashMap<String, usize>,
    /// Pairwise weight for each ordered task pair.
    pub hierarchy_weights: HashMap<(String, String), f64>,
}
176
/// Per-task and shared feature masks with their importance scores.
#[derive(Debug, Clone)]
pub struct MultiTaskFeatureSelector {
    /// Per-task boolean mask (true = feature kept for that task).
    pub task_feature_masks: HashMap<String, Vec<bool>>,
    /// Mask of features considered important across most tasks.
    pub shared_feature_mask: Vec<bool>,
    /// Per-task, per-feature importance scores (variance-based here).
    pub task_feature_importances: HashMap<String, Vec<f64>>,
    /// Importance scores averaged over all tasks.
    pub global_feature_importances: Vec<f64>,
}
189
/// Training data for a single task.
#[derive(Debug, Clone)]
pub struct TaskData {
    /// Unique identifier for the task.
    pub task_id: String,
    /// Feature matrix: rows = samples, columns = features.
    pub features: Array2<f64>,
    /// One regression target per row of `features`.
    pub labels: Vec<f64>,
    /// Optional per-sample weights.
    /// NOTE(review): not read by the training code visible in this file.
    pub sample_weights: Option<Vec<f64>>,
    /// Free-form task metadata.
    pub metadata: HashMap<String, String>,
}
204
/// Summary produced by the internal training routines.
#[derive(Debug, Clone)]
pub struct MultiTaskTrainingResults {
    /// Per-task training metrics keyed by task id.
    pub task_metrics: HashMap<String, TaskMetrics>,
    /// Estimated transfer effect per task pair (currently always empty).
    pub transfer_effects: HashMap<(String, String), f64>,
    /// Pairwise similarities as computed during fitting.
    pub final_similarities: HashMap<(String, String), f64>,
    /// Convergence diagnostics (currently filled with fixed values).
    pub convergence_info: ConvergenceInfo,
}
217
/// Metrics recorded for one task during training.
#[derive(Debug, Clone)]
pub struct TaskMetrics {
    /// Training MSE (lower is better).
    pub training_score: f64,
    /// Validation score; currently mirrors `training_score` (no hold-out split).
    pub validation_score: f64,
    /// Number of training samples for the task.
    pub n_samples: usize,
    /// Wall-clock training time; currently not measured (always 0.0).
    pub training_time: f64,
    /// Complexity proxy (number of models in the task's ensemble).
    pub complexity: f64,
}
232
/// Convergence diagnostics for the training procedure.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Number of outer iterations performed.
    pub n_iterations: usize,
    /// Final value of the training loss.
    pub final_loss: f64,
    /// Tolerance actually achieved at termination.
    pub tolerance_achieved: f64,
    /// Whether training converged before hitting iteration limits.
    pub converged: bool,
}
245
impl MultiTaskEnsembleConfig {
    /// Start building a configuration from the defaults.
    pub fn builder() -> MultiTaskEnsembleConfigBuilder {
        MultiTaskEnsembleConfigBuilder::default()
    }
}
251
/// Fluent builder for [`MultiTaskEnsembleConfig`], seeded with its defaults.
#[derive(Default)]
pub struct MultiTaskEnsembleConfigBuilder {
    config: MultiTaskEnsembleConfig,
}
256
257impl MultiTaskEnsembleConfigBuilder {
258 pub fn n_estimators_per_task(mut self, n_estimators: usize) -> Self {
259 self.config.n_estimators_per_task = n_estimators;
260 self
261 }
262
263 pub fn sharing_strategy(mut self, strategy: TaskSharingStrategy) -> Self {
264 self.config.sharing_strategy = strategy;
265 self
266 }
267
268 pub fn similarity_metric(mut self, metric: TaskSimilarityMetric) -> Self {
269 self.config.similarity_metric = metric;
270 self
271 }
272
273 pub fn min_similarity_threshold(mut self, threshold: f64) -> Self {
274 self.config.min_similarity_threshold = threshold;
275 self
276 }
277
278 pub fn task_weighting(mut self, weighting: TaskWeightingStrategy) -> Self {
279 self.config.task_weighting = weighting;
280 self
281 }
282
283 pub fn use_task_specific_features(mut self, use_specific: bool) -> Self {
284 self.config.use_task_specific_features = use_specific;
285 self
286 }
287
288 pub fn sharing_regularization(mut self, regularization: f64) -> Self {
289 self.config.sharing_regularization = regularization;
290 self
291 }
292
293 pub fn cross_task_validation(mut self, validation: CrossTaskValidation) -> Self {
294 self.config.cross_task_validation = validation;
295 self
296 }
297
298 pub fn build(self) -> MultiTaskEnsembleConfig {
299 self.config
300 }
301}
302
303impl MultiTaskEnsembleRegressor {
    /// Create an untrained regressor with the given configuration.
    ///
    /// All model state starts as `None` and is populated by `fit_tasks`.
    pub fn new(config: MultiTaskEnsembleConfig) -> Self {
        Self {
            config,
            state: std::marker::PhantomData,
            task_models: None,
            shared_models: None,
            task_similarities: None,
            task_weights: None,
            feature_selector: None,
            task_hierarchy: None,
        }
    }
316
    /// Start building a regressor with a fluent API.
    pub fn builder() -> MultiTaskEnsembleRegressorBuilder {
        MultiTaskEnsembleRegressorBuilder::new()
    }
320
321 pub fn fit_tasks(
323 mut self,
324 tasks: &[TaskData],
325 ) -> SklResult<MultiTaskEnsembleRegressor<Trained>> {
326 if tasks.is_empty() {
327 return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
328 }
329
330 self.initialize_task_weights(tasks)?;
332
333 if matches!(
335 self.config.sharing_strategy,
336 TaskSharingStrategy::HierarchicalSharing
337 ) {
338 self.build_task_hierarchy(tasks)?;
339 }
340
341 if self.config.use_task_specific_features {
343 self.initialize_feature_selector(tasks)?;
344 }
345
346 self.compute_task_similarities(tasks)?;
348
349 let training_results = match self.config.sharing_strategy {
351 TaskSharingStrategy::Independent => self.train_independent_tasks(tasks)?,
352 TaskSharingStrategy::SharedRepresentation => self.train_shared_representation(tasks)?,
353 TaskSharingStrategy::ParameterSharing => self.train_parameter_sharing(tasks)?,
354 TaskSharingStrategy::HierarchicalSharing => self.train_hierarchical_sharing(tasks)?,
355 TaskSharingStrategy::AdaptiveSharing => self.train_adaptive_sharing(tasks)?,
356 _ => self.train_independent_tasks(tasks)?, };
358
359 let fitted_ensemble = MultiTaskEnsembleRegressor::<Trained> {
360 config: self.config,
361 state: std::marker::PhantomData,
362 task_models: self.task_models,
363 shared_models: self.shared_models,
364 task_similarities: self.task_similarities,
365 task_weights: self.task_weights,
366 feature_selector: self.feature_selector,
367 task_hierarchy: self.task_hierarchy,
368 };
369
370 Ok(fitted_ensemble)
371 }
372
373 fn initialize_task_weights(&mut self, tasks: &[TaskData]) -> SklResult<()> {
375 let mut weights = HashMap::new();
376
377 match self.config.task_weighting {
378 TaskWeightingStrategy::Uniform => {
379 let weight = 1.0 / tasks.len() as f64;
380 for task in tasks {
381 weights.insert(task.task_id.clone(), weight);
382 }
383 }
384 TaskWeightingStrategy::SampleSizeBased => {
385 let total_samples: usize = tasks.iter().map(|t| t.features.shape()[0]).sum();
386 for task in tasks {
387 let weight = task.features.shape()[0] as f64 / total_samples as f64;
388 weights.insert(task.task_id.clone(), weight);
389 }
390 }
391 _ => {
392 let weight = 1.0 / tasks.len() as f64;
394 for task in tasks {
395 weights.insert(task.task_id.clone(), weight);
396 }
397 }
398 }
399
400 self.task_weights = Some(weights);
401 Ok(())
402 }
403
    /// Build a flat placeholder hierarchy over the tasks.
    ///
    /// NOTE(review): every task gets depth 0 and an empty child list, so the
    /// hierarchy itself is currently degenerate; only `hierarchy_weights`
    /// carries information (weight decays with the distance between task
    /// indices). Confirm whether real hierarchy construction is planned.
    fn build_task_hierarchy(&mut self, tasks: &[TaskData]) -> SklResult<()> {
        let mut hierarchy = HashMap::new();
        let mut task_depths = HashMap::new();
        let mut hierarchy_weights = HashMap::new();

        for (i, task1) in tasks.iter().enumerate() {
            // All tasks sit at the root level for now.
            task_depths.insert(task1.task_id.clone(), 0);
            hierarchy.insert(task1.task_id.clone(), Vec::new());

            for (j, task2) in tasks.iter().enumerate() {
                if i != j {
                    // Closer task indices get larger pairwise weights.
                    let weight = 1.0 / (i.abs_diff(j) + 1) as f64;
                    hierarchy_weights
                        .insert((task1.task_id.clone(), task2.task_id.clone()), weight);
                }
            }
        }

        self.task_hierarchy = Some(TaskHierarchy {
            hierarchy,
            task_depths,
            hierarchy_weights,
        });

        Ok(())
    }
434
    /// Derive per-task and shared feature masks from per-feature variance.
    ///
    /// A feature is kept for a task when its (population) variance exceeds a
    /// small epsilon; it is considered shared when it is "important" (above
    /// half the global mean importance) for at least half of the tasks.
    fn initialize_feature_selector(&mut self, tasks: &[TaskData]) -> SklResult<()> {
        if tasks.is_empty() {
            return Ok(());
        }

        // assumes all tasks share tasks[0]'s feature count — TODO confirm
        let n_features = tasks[0].features.shape()[1];
        let mut task_feature_masks = HashMap::new();
        let mut task_feature_importances = HashMap::new();

        for task in tasks {
            let mut feature_mask = vec![true; n_features];
            let mut feature_importances = vec![0.0; n_features];

            for j in 0..n_features {
                // Column extraction followed by mean / population variance.
                let column: Vec<f64> = (0..task.features.shape()[0])
                    .map(|i| task.features[[i, j]])
                    .collect();
                let mean = column.iter().sum::<f64>() / column.len() as f64;
                let variance =
                    column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / column.len() as f64;

                // Variance doubles as the importance score; near-constant
                // features are masked out.
                feature_importances[j] = variance;
                feature_mask[j] = variance > 1e-8;
            }

            task_feature_masks.insert(task.task_id.clone(), feature_mask);
            task_feature_importances.insert(task.task_id.clone(), feature_importances);
        }

        // Global importance = mean of the per-task importances per feature.
        let mut global_feature_importances = vec![0.0; n_features];
        for j in 0..n_features {
            let sum: f64 = task_feature_importances
                .values()
                .map(|importances| importances[j])
                .sum();
            global_feature_importances[j] = sum / tasks.len() as f64;
        }

        // Shared mask: feature is important for at least half of the tasks.
        let shared_feature_mask: Vec<bool> = (0..n_features)
            .map(|j| {
                let important_count = task_feature_importances
                    .values()
                    .filter(|importances| importances[j] > global_feature_importances[j] * 0.5)
                    .count();
                important_count >= tasks.len() / 2
            })
            .collect();

        self.feature_selector = Some(MultiTaskFeatureSelector {
            task_feature_masks,
            shared_feature_mask,
            task_feature_importances,
            global_feature_importances,
        });

        Ok(())
    }
497
498 fn compute_task_similarities(&mut self, tasks: &[TaskData]) -> SklResult<()> {
500 let mut similarities = HashMap::new();
501
502 for (i, task1) in tasks.iter().enumerate() {
503 for task2 in tasks.iter().skip(i + 1) {
504 let similarity = self.calculate_task_similarity(task1, task2)?;
505 similarities.insert((task1.task_id.clone(), task2.task_id.clone()), similarity);
506 similarities.insert((task2.task_id.clone(), task1.task_id.clone()), similarity);
507 }
508 }
509
510 self.task_similarities = Some(similarities);
511 Ok(())
512 }
513
514 fn calculate_task_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
516 match self.config.similarity_metric {
517 TaskSimilarityMetric::CorrelationBased => self.correlation_similarity(task1, task2),
518 TaskSimilarityMetric::DistributionSimilarity => {
519 self.distribution_similarity(task1, task2)
520 }
521 _ => {
522 self.correlation_similarity(task1, task2)
524 }
525 }
526 }
527
528 fn correlation_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
530 if task1.labels.len() != task2.labels.len() {
532 return Ok(0.0); }
534
535 let n = task1.labels.len();
536 if n < 2 {
537 return Ok(0.0);
538 }
539
540 let mean1 = task1.labels.iter().sum::<f64>() / n as f64;
541 let mean2 = task2.labels.iter().sum::<f64>() / n as f64;
542
543 let mut numerator = 0.0;
544 let mut denom1 = 0.0;
545 let mut denom2 = 0.0;
546
547 for i in 0..n {
548 let diff1 = task1.labels[i] - mean1;
549 let diff2 = task2.labels[i] - mean2;
550 numerator += diff1 * diff2;
551 denom1 += diff1 * diff1;
552 denom2 += diff2 * diff2;
553 }
554
555 if denom1 * denom2 > 0.0 {
556 Ok(numerator / (denom1 * denom2).sqrt())
557 } else {
558 Ok(0.0)
559 }
560 }
561
562 fn distribution_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
564 let n_features1 = task1.features.shape()[1];
566 let n_features2 = task2.features.shape()[1];
567
568 if n_features1 != n_features2 {
569 return Ok(0.0);
570 }
571
572 let mut similarity_sum = 0.0;
573
574 for j in 0..n_features1 {
575 let col1: Vec<f64> = (0..task1.features.shape()[0])
576 .map(|i| task1.features[[i, j]])
577 .collect();
578 let col2: Vec<f64> = (0..task2.features.shape()[0])
579 .map(|i| task2.features[[i, j]])
580 .collect();
581
582 let mean1 = col1.iter().sum::<f64>() / col1.len() as f64;
583 let mean2 = col2.iter().sum::<f64>() / col2.len() as f64;
584
585 let var1 = col1.iter().map(|&x| (x - mean1).powi(2)).sum::<f64>() / col1.len() as f64;
586 let var2 = col2.iter().map(|&x| (x - mean2).powi(2)).sum::<f64>() / col2.len() as f64;
587
588 let mean_sim = 1.0 - (mean1 - mean2).abs() / (mean1.abs() + mean2.abs() + 1e-8);
590 let var_sim = 1.0 - (var1 - var2).abs() / (var1 + var2 + 1e-8);
591
592 similarity_sum += (mean_sim + var_sim) / 2.0;
593 }
594
595 Ok(similarity_sum / n_features1 as f64)
596 }
597
598 fn train_independent_tasks(
600 &mut self,
601 tasks: &[TaskData],
602 ) -> SklResult<MultiTaskTrainingResults> {
603 let mut task_metrics = HashMap::new();
604
605 if self.task_models.is_none() {
607 self.task_models = Some(HashMap::new());
608 }
609
610 for task in tasks {
611 let mut models = Vec::new();
612
613 for _ in 0..self.config.n_estimators_per_task {
614 let gb_config = GradientBoostingConfig {
615 n_estimators: 50,
616 learning_rate: 0.1,
617 max_depth: 6,
618 ..Default::default()
619 };
620
621 let y_array = Array1::from_vec(task.labels.clone());
622 let model = GradientBoostingRegressor::builder()
623 .n_estimators(50)
624 .learning_rate(0.1)
625 .max_depth(6)
626 .build()
627 .fit(&task.features, &y_array)?;
628
629 models.push(model);
630 }
631
632 let predictions = self.predict_task_ensemble(&models, &task.features)?;
634 let mse = self.calculate_mse(&predictions, &task.labels);
635
636 task_metrics.insert(
637 task.task_id.clone(),
638 TaskMetrics {
639 training_score: mse,
640 validation_score: mse, n_samples: task.features.shape()[0],
642 training_time: 0.0, complexity: models.len() as f64,
644 },
645 );
646
647 self.task_models
648 .as_mut()
649 .unwrap()
650 .insert(task.task_id.clone(), models);
651 }
652
653 Ok(MultiTaskTrainingResults {
654 task_metrics,
655 transfer_effects: HashMap::new(),
656 final_similarities: self.task_similarities.clone().unwrap_or_default(),
657 convergence_info: ConvergenceInfo {
658 n_iterations: 1,
659 final_loss: 0.0,
660 tolerance_achieved: 0.0,
661 converged: true,
662 },
663 })
664 }
665
666 fn train_shared_representation(
668 &mut self,
669 tasks: &[TaskData],
670 ) -> SklResult<MultiTaskTrainingResults> {
671 let combined_data = self.combine_task_data(tasks)?;
673
674 if self.shared_models.is_none() {
676 self.shared_models = Some(Vec::new());
677 }
678
679 for _ in 0..self.config.n_estimators_per_task {
680 let gb_config = GradientBoostingConfig {
681 n_estimators: 30,
682 learning_rate: 0.1,
683 max_depth: 4,
684 ..Default::default()
685 };
686
687 let y_array = Array1::from_vec(combined_data.labels.clone());
688 let shared_model = GradientBoostingRegressor::builder()
689 .n_estimators(30)
690 .learning_rate(0.1)
691 .max_depth(4)
692 .build()
693 .fit(&combined_data.features, &y_array)?;
694
695 self.shared_models.as_mut().unwrap().push(shared_model);
696 }
697
698 self.train_independent_tasks(tasks)
700 }
701
702 fn train_parameter_sharing(
704 &mut self,
705 tasks: &[TaskData],
706 ) -> SklResult<MultiTaskTrainingResults> {
707 let task_groups = self.group_similar_tasks(tasks)?;
709
710 let mut task_metrics = HashMap::new();
711
712 for group in task_groups {
713 let group_data = self.combine_group_data(tasks, &group)?;
715
716 let mut group_models = Vec::new();
717 for _ in 0..self.config.n_estimators_per_task {
718 let gb_config = GradientBoostingConfig {
719 n_estimators: 40,
720 learning_rate: 0.1,
721 max_depth: 5,
722 ..Default::default()
723 };
724
725 let y_array = Array1::from_vec(group_data.labels.clone());
726 let model = GradientBoostingRegressor::builder()
727 .n_estimators(40)
728 .learning_rate(0.1)
729 .max_depth(5)
730 .build()
731 .fit(&group_data.features, &y_array)?;
732
733 group_models.push(model);
734 }
735
736 for task_id in &group {
738 let task = tasks.iter().find(|t| &t.task_id == task_id).unwrap();
739
740 let predictions = self.predict_task_ensemble(&group_models, &task.features)?;
742 let mse = self.calculate_mse(&predictions, &task.labels);
743
744 task_metrics.insert(
745 task_id.clone(),
746 TaskMetrics {
747 training_score: mse,
748 validation_score: mse,
749 n_samples: task.features.shape()[0],
750 training_time: 0.0,
751 complexity: group_models.len() as f64,
752 },
753 );
754
755 let mut task_models = Vec::new();
757 for _ in 0..self.config.n_estimators_per_task {
758 let y_array = Array1::from_vec(group_data.labels.clone());
759 let model = GradientBoostingRegressor::builder()
760 .n_estimators(40)
761 .learning_rate(0.1)
762 .max_depth(5)
763 .build()
764 .fit(&group_data.features, &y_array)?;
765 task_models.push(model);
766 }
767
768 self.task_models
769 .as_mut()
770 .unwrap()
771 .insert(task_id.clone(), task_models);
772 }
773 }
774
775 Ok(MultiTaskTrainingResults {
776 task_metrics,
777 transfer_effects: HashMap::new(),
778 final_similarities: self.task_similarities.clone().unwrap_or_default(),
779 convergence_info: ConvergenceInfo {
780 n_iterations: 1,
781 final_loss: 0.0,
782 tolerance_achieved: 0.0,
783 converged: true,
784 },
785 })
786 }
787
    /// Hierarchical sharing is not yet implemented as its own routine; it
    /// currently delegates to parameter sharing, which groups tasks by
    /// similarity. The hierarchy built earlier is not consulted here.
    fn train_hierarchical_sharing(
        &mut self,
        tasks: &[TaskData],
    ) -> SklResult<MultiTaskTrainingResults> {
        self.train_parameter_sharing(tasks)
    }
797
798 fn train_adaptive_sharing(
800 &mut self,
801 tasks: &[TaskData],
802 ) -> SklResult<MultiTaskTrainingResults> {
803 let mut task_metrics = HashMap::new();
804
805 for task in tasks {
806 let mut models = Vec::new();
807
808 let similar_tasks = self.find_similar_tasks(&task.task_id, tasks);
810
811 if similar_tasks.len() > 1 {
812 let combined_data = self.combine_similar_task_data(tasks, &similar_tasks)?;
814
815 for _ in 0..self.config.n_estimators_per_task {
816 let gb_config = GradientBoostingConfig {
817 n_estimators: 50,
818 learning_rate: 0.1,
819 max_depth: 6,
820 ..Default::default()
821 };
822
823 let y_array = Array1::from_vec(combined_data.labels.clone());
824 let model = GradientBoostingRegressor::builder()
825 .n_estimators(50)
826 .learning_rate(0.1)
827 .max_depth(6)
828 .build()
829 .fit(&combined_data.features, &y_array)?;
830
831 models.push(model);
832 }
833 } else {
834 for _ in 0..self.config.n_estimators_per_task {
836 let gb_config = GradientBoostingConfig {
837 n_estimators: 50,
838 learning_rate: 0.1,
839 max_depth: 6,
840 ..Default::default()
841 };
842
843 let y_array = Array1::from_vec(task.labels.clone());
844 let model = GradientBoostingRegressor::builder()
845 .n_estimators(50)
846 .learning_rate(0.1)
847 .max_depth(6)
848 .build()
849 .fit(&task.features, &y_array)?;
850
851 models.push(model);
852 }
853 }
854
855 let predictions = self.predict_task_ensemble(&models, &task.features)?;
857 let mse = self.calculate_mse(&predictions, &task.labels);
858
859 task_metrics.insert(
860 task.task_id.clone(),
861 TaskMetrics {
862 training_score: mse,
863 validation_score: mse,
864 n_samples: task.features.shape()[0],
865 training_time: 0.0,
866 complexity: models.len() as f64,
867 },
868 );
869
870 self.task_models
871 .as_mut()
872 .unwrap()
873 .insert(task.task_id.clone(), models);
874 }
875
876 Ok(MultiTaskTrainingResults {
877 task_metrics,
878 transfer_effects: HashMap::new(),
879 final_similarities: self.task_similarities.clone().unwrap_or_default(),
880 convergence_info: ConvergenceInfo {
881 n_iterations: 1,
882 final_loss: 0.0,
883 tolerance_achieved: 0.0,
884 converged: true,
885 },
886 })
887 }
888
889 fn combine_task_data(&self, tasks: &[TaskData]) -> SklResult<TaskData> {
891 if tasks.is_empty() {
892 return Err(SklearsError::InvalidInput(
893 "No tasks to combine".to_string(),
894 ));
895 }
896
897 let total_samples: usize = tasks.iter().map(|t| t.features.shape()[0]).sum();
898 let n_features = tasks[0].features.shape()[1];
899
900 let mut combined_features = Vec::with_capacity(total_samples * n_features);
901 let mut combined_labels = Vec::with_capacity(total_samples);
902
903 for task in tasks {
904 for i in 0..task.features.shape()[0] {
905 for j in 0..n_features {
906 combined_features.push(task.features[[i, j]]);
907 }
908 combined_labels.push(task.labels[i]);
909 }
910 }
911
912 let features = Array2::from_shape_vec((total_samples, n_features), combined_features)?;
913
914 Ok(TaskData {
915 task_id: "combined".to_string(),
916 features,
917 labels: combined_labels,
918 sample_weights: None,
919 metadata: HashMap::new(),
920 })
921 }
922
    /// Greedily partition tasks into groups of similar tasks.
    ///
    /// Iterates tasks in order; each unassigned task seeds a new group and
    /// absorbs every later unassigned task whose cached similarity to the
    /// seed meets the threshold. Missing similarity entries count as 0.0.
    fn group_similar_tasks(&self, tasks: &[TaskData]) -> SklResult<Vec<Vec<String>>> {
        let mut groups = Vec::new();
        let mut assigned = vec![false; tasks.len()];

        for (i, task) in tasks.iter().enumerate() {
            if assigned[i] {
                continue;
            }

            // This task seeds a new group.
            let mut group = vec![task.task_id.clone()];
            assigned[i] = true;

            for (j, other_task) in tasks.iter().enumerate() {
                if i != j && !assigned[j] {
                    let similarity = self
                        .task_similarities
                        .as_ref()
                        .and_then(|similarities| {
                            similarities.get(&(task.task_id.clone(), other_task.task_id.clone()))
                        })
                        .copied()
                        .unwrap_or(0.0);

                    // Similarity is measured against the group seed only,
                    // not against all current group members.
                    if similarity >= self.config.min_similarity_threshold {
                        group.push(other_task.task_id.clone());
                        assigned[j] = true;
                    }
                }
            }

            groups.push(group);
        }

        Ok(groups)
    }
960
961 fn combine_group_data(&self, tasks: &[TaskData], group: &[String]) -> SklResult<TaskData> {
963 let group_tasks: Vec<TaskData> = tasks
964 .iter()
965 .filter(|t| group.contains(&t.task_id))
966 .cloned()
967 .collect();
968
969 self.combine_task_data(&group_tasks)
970 }
971
972 fn find_similar_tasks(&self, task_id: &str, tasks: &[TaskData]) -> Vec<String> {
974 let mut similar_tasks = vec![task_id.to_string()];
975
976 for task in tasks {
977 if task.task_id != task_id {
978 let similarity = self
979 .task_similarities
980 .as_ref()
981 .and_then(|similarities| {
982 similarities.get(&(task_id.to_string(), task.task_id.clone()))
983 })
984 .copied()
985 .unwrap_or(0.0);
986
987 if similarity >= self.config.min_similarity_threshold {
988 similar_tasks.push(task.task_id.clone());
989 }
990 }
991 }
992
993 similar_tasks
994 }
995
996 fn combine_similar_task_data(
998 &self,
999 tasks: &[TaskData],
1000 similar_task_ids: &[String],
1001 ) -> SklResult<TaskData> {
1002 let similar_tasks: Vec<TaskData> = tasks
1003 .iter()
1004 .filter(|t| similar_task_ids.contains(&t.task_id))
1005 .cloned()
1006 .collect();
1007
1008 self.combine_task_data(&similar_tasks)
1009 }
1010
1011 fn predict_task_ensemble(
1013 &self,
1014 models: &[TrainedGradientBoostingRegressor],
1015 X: &Array2<f64>,
1016 ) -> SklResult<Vec<f64>> {
1017 if models.is_empty() {
1018 return Err(SklearsError::InvalidInput(
1019 "No models in ensemble".to_string(),
1020 ));
1021 }
1022
1023 let mut predictions = vec![0.0; X.shape()[0]];
1024
1025 for model in models {
1026 let pred = model.predict(X)?;
1027 for (i, &p) in pred.iter().enumerate() {
1028 predictions[i] += p;
1029 }
1030 }
1031
1032 for p in &mut predictions {
1034 *p /= models.len() as f64;
1035 }
1036
1037 Ok(predictions)
1038 }
1039
1040 fn calculate_mse(&self, predictions: &[f64], targets: &[f64]) -> f64 {
1042 if predictions.len() != targets.len() {
1043 return f64::INFINITY;
1044 }
1045
1046 let sum_squared_error: f64 = predictions
1047 .iter()
1048 .zip(targets.iter())
1049 .map(|(&p, &t)| (p - t).powi(2))
1050 .sum();
1051
1052 sum_squared_error / predictions.len() as f64
1053 }
1054}
1055
1056impl MultiTaskEnsembleRegressor<Trained> {
1057 pub fn predict_task(&self, task_id: &str, X: &Array2<f64>) -> SklResult<Vec<f64>> {
1059 if let Some(models) = self.task_models.as_ref().and_then(|m| m.get(task_id)) {
1061 let mut predictions = self.predict_task_ensemble(models, X)?;
1062
1063 if let Some(shared_models) = self.shared_models.as_ref() {
1065 if !shared_models.is_empty() {
1066 let shared_predictions = self.predict_task_ensemble(shared_models, X)?;
1067 for (i, &shared_pred) in shared_predictions.iter().enumerate() {
1068 predictions[i] = 0.7 * predictions[i] + 0.3 * shared_pred;
1069 }
1070 }
1071 }
1072
1073 Ok(predictions)
1074 } else {
1075 Err(SklearsError::InvalidInput(format!(
1076 "Task '{}' not found",
1077 task_id
1078 )))
1079 }
1080 }
1081
1082 fn predict_task_ensemble(
1084 &self,
1085 models: &[TrainedGradientBoostingRegressor],
1086 X: &Array2<f64>,
1087 ) -> SklResult<Vec<f64>> {
1088 if models.is_empty() {
1089 return Err(SklearsError::InvalidInput(
1090 "No models in ensemble".to_string(),
1091 ));
1092 }
1093
1094 let mut predictions = vec![0.0; X.shape()[0]];
1095
1096 for model in models {
1097 let pred = model.predict(X)?;
1098 for (i, &p) in pred.iter().enumerate() {
1099 predictions[i] += p;
1100 }
1101 }
1102
1103 for p in &mut predictions {
1105 *p /= models.len() as f64;
1106 }
1107
1108 Ok(predictions)
1109 }
1110
    /// Cached pairwise task similarities computed during fitting
    /// (stored for both orderings of each pair).
    ///
    /// # Panics
    /// Panics if the internal state is missing — should be impossible for a
    /// value in the `Trained` typestate.
    pub fn get_task_similarities(&self) -> &HashMap<(String, String), f64> {
        self.task_similarities.as_ref().expect("Model is trained")
    }
1115
    /// Per-task weights assigned during fitting.
    ///
    /// # Panics
    /// Panics if the internal state is missing — should be impossible for a
    /// value in the `Trained` typestate.
    pub fn get_task_weights(&self) -> &HashMap<String, f64> {
        self.task_weights.as_ref().expect("Model is trained")
    }
1120}
1121
/// Fluent builder for [`MultiTaskEnsembleRegressor`].
pub struct MultiTaskEnsembleRegressorBuilder {
    config: MultiTaskEnsembleConfig,
}
1125
impl Default for MultiTaskEnsembleRegressorBuilder {
    /// Equivalent to [`MultiTaskEnsembleRegressorBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
1131
1132impl MultiTaskEnsembleRegressorBuilder {
1133 pub fn new() -> Self {
1134 Self {
1135 config: MultiTaskEnsembleConfig::default(),
1136 }
1137 }
1138
1139 pub fn config(mut self, config: MultiTaskEnsembleConfig) -> Self {
1140 self.config = config;
1141 self
1142 }
1143
1144 pub fn n_estimators_per_task(mut self, n_estimators: usize) -> Self {
1145 self.config.n_estimators_per_task = n_estimators;
1146 self
1147 }
1148
1149 pub fn sharing_strategy(mut self, strategy: TaskSharingStrategy) -> Self {
1150 self.config.sharing_strategy = strategy;
1151 self
1152 }
1153
1154 pub fn min_similarity_threshold(mut self, threshold: f64) -> Self {
1155 self.config.min_similarity_threshold = threshold;
1156 self
1157 }
1158
1159 pub fn build(self) -> MultiTaskEnsembleRegressor {
1160 MultiTaskEnsembleRegressor::new(self.config)
1161 }
1162}
1163
/// Minimal [`Estimator`] implementation exposing the configuration.
impl Estimator for MultiTaskEnsembleRegressor {
    type Config = MultiTaskEnsembleConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Borrow the ensemble configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
1173
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builder setters should be forwarded into the resulting config.
    #[test]
    fn test_multi_task_config() {
        let config = MultiTaskEnsembleConfig::builder()
            .n_estimators_per_task(5)
            .sharing_strategy(TaskSharingStrategy::SharedRepresentation)
            .min_similarity_threshold(0.5)
            .build();

        assert_eq!(config.n_estimators_per_task, 5);
        assert_eq!(
            config.sharing_strategy,
            TaskSharingStrategy::SharedRepresentation
        );
        assert_eq!(config.min_similarity_threshold, 0.5);
    }

    /// A TaskData value should retain its id, shape, and labels.
    #[test]
    fn test_task_data_creation() {
        let features = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let labels = vec![1.0, 2.0, 3.0];

        let task = TaskData {
            task_id: "test_task".to_string(),
            features,
            labels,
            sample_weights: None,
            metadata: HashMap::new(),
        };

        assert_eq!(task.task_id, "test_task");
        assert_eq!(task.features.shape(), &[3, 2]);
        assert_eq!(task.labels.len(), 3);
    }

    /// A freshly constructed ensemble holds the config and no models.
    #[test]
    fn test_multi_task_ensemble_basic() {
        let config = MultiTaskEnsembleConfig::builder()
            .n_estimators_per_task(2)
            .sharing_strategy(TaskSharingStrategy::Independent)
            .build();

        let ensemble = MultiTaskEnsembleRegressor::new(config);

        assert_eq!(ensemble.config.n_estimators_per_task, 2);
        assert_eq!(
            ensemble.config.sharing_strategy,
            TaskSharingStrategy::Independent
        );
        assert!(ensemble.task_models.is_none());
    }

    /// Identical label vectors should have correlation similarity 1.0.
    #[test]
    fn test_task_similarity_calculation() {
        let config = MultiTaskEnsembleConfig::default();
        let ensemble = MultiTaskEnsembleRegressor::new(config);

        let task1 = TaskData {
            task_id: "task1".to_string(),
            features: Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
            labels: vec![1.0, 2.0],
            sample_weights: None,
            metadata: HashMap::new(),
        };

        let task2 = TaskData {
            task_id: "task2".to_string(),
            features: Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
            labels: vec![1.0, 2.0],
            sample_weights: None,
            metadata: HashMap::new(),
        };

        let similarity = ensemble.correlation_similarity(&task1, &task2).unwrap();
        assert!((similarity - 1.0).abs() < 1e-10);
    }
}