1use crate::bagging::BaggingClassifier;
8use crate::gradient_boosting::{
9 GradientBoostingConfig, GradientBoostingRegressor, TrainedGradientBoostingRegressor,
10};
11use scirs2_core::ndarray::{Array1, Array2};
12use sklears_core::{
13 error::Result as SklResult,
14 prelude::{Predict, SklearsError},
15 traits::{Estimator, Fit, Trained, Untrained},
16};
17use std::collections::HashMap;
18
/// Marker trait bundling the capabilities required of a model that can
/// participate in a multi-task ensemble: it must be an [`Estimator`] with the
/// given config/error/float types and support [`Predict`] from `X` to `Y`.
pub trait MultiTaskEstimator<C, E, F, X, Y>:
    Estimator<Config = C, Error = E, Float = F> + Predict<X, Y>
{
}
24
// Blanket implementation: any type that already satisfies both bounds is
// automatically a `MultiTaskEstimator`; no manual opt-in is needed.
impl<T, C, E, F, X, Y> MultiTaskEstimator<C, E, F, X, Y> for T where
    T: Estimator<Config = C, Error = E, Float = F> + Predict<X, Y>
{
}
29
/// Configuration for a multi-task ensemble (classifier or regressor).
#[derive(Debug, Clone)]
pub struct MultiTaskEnsembleConfig {
    /// Number of base estimators trained per task.
    pub n_estimators_per_task: usize,
    /// How knowledge is shared across tasks during training.
    pub sharing_strategy: TaskSharingStrategy,
    /// Metric used to compute pairwise task similarity.
    pub similarity_metric: TaskSimilarityMetric,
    /// Minimum similarity for two tasks to be grouped / share data.
    pub min_similarity_threshold: f64,
    /// How each task's contribution is weighted.
    pub task_weighting: TaskWeightingStrategy,
    /// Whether to compute per-task feature masks/importances before training.
    pub use_task_specific_features: bool,
    /// Optional cap on the number of shared features (None = unlimited).
    pub n_shared_features: Option<usize>,
    /// Regularization strength applied to cross-task sharing.
    pub sharing_regularization: f64,
    /// Maximum depth of the task hierarchy (hierarchical sharing only).
    pub max_task_depth: usize,
    /// Cross-validation scheme used across tasks.
    pub cross_task_validation: CrossTaskValidation,
}
54
55impl Default for MultiTaskEnsembleConfig {
56 fn default() -> Self {
57 Self {
58 n_estimators_per_task: 10,
59 sharing_strategy: TaskSharingStrategy::SharedRepresentation,
60 similarity_metric: TaskSimilarityMetric::CorrelationBased,
61 min_similarity_threshold: 0.3,
62 task_weighting: TaskWeightingStrategy::Uniform,
63 use_task_specific_features: true,
64 n_shared_features: None,
65 sharing_regularization: 0.1,
66 max_task_depth: 3,
67 cross_task_validation: CrossTaskValidation::LeaveOneTaskOut,
68 }
69 }
70}
71
/// Strategy controlling how knowledge is shared between tasks.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskSharingStrategy {
    /// Train each task in isolation; no sharing.
    Independent,
    /// Train shared models on the pooled data of all tasks, plus per-task models.
    SharedRepresentation,
    /// Group similar tasks and share model parameters within each group.
    ParameterSharing,
    /// Share parameters along a learned task hierarchy.
    HierarchicalSharing,
    /// Decide per task whether to borrow data from similar tasks.
    AdaptiveSharing,
    /// Share at multiple levels of granularity.
    MultiLevelSharing,
    /// Transfer knowledge from source tasks to target tasks.
    TransferLearning,
}
90
/// Metric used to quantify how similar two tasks are.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskSimilarityMetric {
    /// Pearson correlation between the tasks' label vectors.
    CorrelationBased,
    /// Similarity of per-feature importance profiles.
    FeatureImportanceSimilarity,
    /// Agreement between model predictions across tasks.
    PredictionSimilarity,
    /// Similarity of per-feature input distributions (means/variances).
    DistributionSimilarity,
    /// Similarity of training gradients.
    GradientSimilarity,
    /// Correlation between task performance profiles.
    PerformanceCorrelation,
}
107
/// Strategy for assigning a weight to each task's contribution.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskWeightingStrategy {
    /// Every task receives the same weight (1 / n_tasks).
    Uniform,
    /// Harder tasks receive larger weights.
    DifficultyBased,
    /// Weight proportional to the task's sample count.
    SampleSizeBased,
    /// Weight derived from observed task performance.
    PerformanceBased,
    /// Weights adapted during training.
    AdaptiveWeighting,
    /// Weight derived from externally supplied task importance.
    ImportanceBased,
}
124
/// Cross-validation scheme applied across tasks.
#[derive(Debug, Clone, PartialEq)]
pub enum CrossTaskValidation {
    /// Hold out one whole task per fold.
    LeaveOneTaskOut,
    /// Standard CV performed within each task.
    WithinTaskCV,
    /// CV respecting the task hierarchy.
    HierarchicalCV,
    /// Time-ordered splits.
    TemporalCV,
    /// Stratified splits across tasks.
    StratifiedCV,
}
139
/// Multi-task ensemble classifier using the type-state pattern: `State` is
/// [`Untrained`] until fitting succeeds, then [`Trained`].
/// Bagging classifiers serve as the per-task base learners.
pub struct MultiTaskEnsembleClassifier<State = Untrained> {
    // User-supplied configuration; immutable after construction.
    config: MultiTaskEnsembleConfig,
    // Zero-sized type-state marker.
    state: std::marker::PhantomData<State>,
    // Per-task ensembles, keyed by task id; populated by training.
    task_models: Option<HashMap<String, Vec<BaggingClassifier<Trained>>>>,
    // Models trained on pooled data (shared-representation strategies).
    shared_models: Option<Vec<BaggingClassifier<Trained>>>,
    // Pairwise task similarities keyed by (task_a, task_b).
    task_similarities: Option<HashMap<(String, String), f64>>,
    // Per-task contribution weights.
    task_weights: Option<HashMap<String, f64>>,
    // Per-task and shared feature masks/importances.
    feature_selector: Option<MultiTaskFeatureSelector>,
    // Task hierarchy (hierarchical sharing only).
    task_hierarchy: Option<TaskHierarchy>,
}
152
/// Multi-task ensemble regressor using the type-state pattern: `State` is
/// [`Untrained`] until [`fit_tasks`](MultiTaskEnsembleRegressor::fit_tasks)
/// succeeds, then [`Trained`]. Gradient boosting regressors serve as the
/// per-task base learners.
pub struct MultiTaskEnsembleRegressor<State = Untrained> {
    // User-supplied configuration; immutable after construction.
    config: MultiTaskEnsembleConfig,
    // Zero-sized type-state marker.
    state: std::marker::PhantomData<State>,
    // Per-task ensembles, keyed by task id; populated by training.
    task_models: Option<HashMap<String, Vec<TrainedGradientBoostingRegressor>>>,
    // Models trained on pooled data (shared-representation strategy).
    shared_models: Option<Vec<TrainedGradientBoostingRegressor>>,
    // Pairwise task similarities keyed by (task_a, task_b); stored symmetrically.
    task_similarities: Option<HashMap<(String, String), f64>>,
    // Per-task contribution weights.
    task_weights: Option<HashMap<String, f64>>,
    // Per-task and shared feature masks/importances.
    feature_selector: Option<MultiTaskFeatureSelector>,
    // Task hierarchy (hierarchical sharing only).
    task_hierarchy: Option<TaskHierarchy>,
}
165
/// Hierarchical organization of tasks used by hierarchical sharing.
#[derive(Debug, Clone)]
pub struct TaskHierarchy {
    /// Parent task id -> child task ids.
    pub hierarchy: HashMap<String, Vec<String>>,
    /// Depth of each task in the hierarchy (root = 0).
    pub task_depths: HashMap<String, usize>,
    /// Sharing weight for each ordered (task_a, task_b) pair.
    pub hierarchy_weights: HashMap<(String, String), f64>,
}
176
/// Per-task and global feature selection state computed before training.
#[derive(Debug, Clone)]
pub struct MultiTaskFeatureSelector {
    /// Per-task boolean mask: `true` means the feature is kept for that task.
    pub task_feature_masks: HashMap<String, Vec<bool>>,
    /// Mask of features considered important across most tasks.
    pub shared_feature_mask: Vec<bool>,
    /// Per-task feature importance scores (currently per-feature variance).
    pub task_feature_importances: HashMap<String, Vec<f64>>,
    /// Importances averaged over all tasks.
    pub global_feature_importances: Vec<f64>,
}
189
/// Training data for a single task.
#[derive(Debug, Clone)]
pub struct TaskData {
    /// Unique identifier for the task.
    pub task_id: String,
    /// Feature matrix, shape (n_samples, n_features).
    pub features: Array2<f64>,
    /// Regression targets; expected to have length n_samples.
    pub labels: Vec<f64>,
    /// Optional per-sample weights (same length as `labels` when present).
    pub sample_weights: Option<Vec<f64>>,
    /// Free-form key/value metadata about the task.
    pub metadata: HashMap<String, String>,
}
204
/// Summary of a multi-task training run.
#[derive(Debug, Clone)]
pub struct MultiTaskTrainingResults {
    /// Per-task training/validation metrics, keyed by task id.
    pub task_metrics: HashMap<String, TaskMetrics>,
    /// Measured transfer effect between ordered task pairs.
    pub transfer_effects: HashMap<(String, String), f64>,
    /// Final pairwise task similarities used during training.
    pub final_similarities: HashMap<(String, String), f64>,
    /// Convergence diagnostics for the training loop.
    pub convergence_info: ConvergenceInfo,
}
217
/// Metrics recorded for a single task during training.
#[derive(Debug, Clone)]
pub struct TaskMetrics {
    /// Score on the training data (currently MSE; lower is better).
    pub training_score: f64,
    /// Score on validation data (currently mirrors the training score).
    pub validation_score: f64,
    /// Number of training samples for this task.
    pub n_samples: usize,
    /// Wall-clock training time in seconds (not yet instrumented; 0.0).
    pub training_time: f64,
    /// Model complexity proxy (currently the number of estimators).
    pub complexity: f64,
}
232
/// Convergence diagnostics for an iterative training procedure.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Number of iterations performed.
    pub n_iterations: usize,
    /// Loss value at the final iteration.
    pub final_loss: f64,
    /// Tolerance actually achieved at termination.
    pub tolerance_achieved: f64,
    /// Whether the procedure met its convergence criterion.
    pub converged: bool,
}
245
246impl MultiTaskEnsembleConfig {
247 pub fn builder() -> MultiTaskEnsembleConfigBuilder {
248 MultiTaskEnsembleConfigBuilder::default()
249 }
250}
251
/// Fluent builder for [`MultiTaskEnsembleConfig`]; starts from the defaults.
#[derive(Default)]
pub struct MultiTaskEnsembleConfigBuilder {
    // Config being assembled; returned by `build()`.
    config: MultiTaskEnsembleConfig,
}
256
257impl MultiTaskEnsembleConfigBuilder {
258 pub fn n_estimators_per_task(mut self, n_estimators: usize) -> Self {
259 self.config.n_estimators_per_task = n_estimators;
260 self
261 }
262
263 pub fn sharing_strategy(mut self, strategy: TaskSharingStrategy) -> Self {
264 self.config.sharing_strategy = strategy;
265 self
266 }
267
268 pub fn similarity_metric(mut self, metric: TaskSimilarityMetric) -> Self {
269 self.config.similarity_metric = metric;
270 self
271 }
272
273 pub fn min_similarity_threshold(mut self, threshold: f64) -> Self {
274 self.config.min_similarity_threshold = threshold;
275 self
276 }
277
278 pub fn task_weighting(mut self, weighting: TaskWeightingStrategy) -> Self {
279 self.config.task_weighting = weighting;
280 self
281 }
282
283 pub fn use_task_specific_features(mut self, use_specific: bool) -> Self {
284 self.config.use_task_specific_features = use_specific;
285 self
286 }
287
288 pub fn sharing_regularization(mut self, regularization: f64) -> Self {
289 self.config.sharing_regularization = regularization;
290 self
291 }
292
293 pub fn cross_task_validation(mut self, validation: CrossTaskValidation) -> Self {
294 self.config.cross_task_validation = validation;
295 self
296 }
297
298 pub fn build(self) -> MultiTaskEnsembleConfig {
299 self.config
300 }
301}
302
303impl MultiTaskEnsembleRegressor {
304 pub fn new(config: MultiTaskEnsembleConfig) -> Self {
305 Self {
306 config,
307 state: std::marker::PhantomData,
308 task_models: None,
309 shared_models: None,
310 task_similarities: None,
311 task_weights: None,
312 feature_selector: None,
313 task_hierarchy: None,
314 }
315 }
316
317 pub fn builder() -> MultiTaskEnsembleRegressorBuilder {
318 MultiTaskEnsembleRegressorBuilder::new()
319 }
320
321 pub fn fit_tasks(
323 mut self,
324 tasks: &[TaskData],
325 ) -> SklResult<MultiTaskEnsembleRegressor<Trained>> {
326 if tasks.is_empty() {
327 return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
328 }
329
330 self.initialize_task_weights(tasks)?;
332
333 if matches!(
335 self.config.sharing_strategy,
336 TaskSharingStrategy::HierarchicalSharing
337 ) {
338 self.build_task_hierarchy(tasks)?;
339 }
340
341 if self.config.use_task_specific_features {
343 self.initialize_feature_selector(tasks)?;
344 }
345
346 self.compute_task_similarities(tasks)?;
348
349 let training_results = match self.config.sharing_strategy {
351 TaskSharingStrategy::Independent => self.train_independent_tasks(tasks)?,
352 TaskSharingStrategy::SharedRepresentation => self.train_shared_representation(tasks)?,
353 TaskSharingStrategy::ParameterSharing => self.train_parameter_sharing(tasks)?,
354 TaskSharingStrategy::HierarchicalSharing => self.train_hierarchical_sharing(tasks)?,
355 TaskSharingStrategy::AdaptiveSharing => self.train_adaptive_sharing(tasks)?,
356 _ => self.train_independent_tasks(tasks)?, };
358
359 let fitted_ensemble = MultiTaskEnsembleRegressor::<Trained> {
360 config: self.config,
361 state: std::marker::PhantomData,
362 task_models: self.task_models,
363 shared_models: self.shared_models,
364 task_similarities: self.task_similarities,
365 task_weights: self.task_weights,
366 feature_selector: self.feature_selector,
367 task_hierarchy: self.task_hierarchy,
368 };
369
370 Ok(fitted_ensemble)
371 }
372
373 fn initialize_task_weights(&mut self, tasks: &[TaskData]) -> SklResult<()> {
375 let mut weights = HashMap::new();
376
377 match self.config.task_weighting {
378 TaskWeightingStrategy::Uniform => {
379 let weight = 1.0 / tasks.len() as f64;
380 for task in tasks {
381 weights.insert(task.task_id.clone(), weight);
382 }
383 }
384 TaskWeightingStrategy::SampleSizeBased => {
385 let total_samples: usize = tasks.iter().map(|t| t.features.shape()[0]).sum();
386 for task in tasks {
387 let weight = task.features.shape()[0] as f64 / total_samples as f64;
388 weights.insert(task.task_id.clone(), weight);
389 }
390 }
391 _ => {
392 let weight = 1.0 / tasks.len() as f64;
394 for task in tasks {
395 weights.insert(task.task_id.clone(), weight);
396 }
397 }
398 }
399
400 self.task_weights = Some(weights);
401 Ok(())
402 }
403
404 fn build_task_hierarchy(&mut self, tasks: &[TaskData]) -> SklResult<()> {
406 let mut hierarchy = HashMap::new();
407 let mut task_depths = HashMap::new();
408 let mut hierarchy_weights = HashMap::new();
409
410 for (i, task1) in tasks.iter().enumerate() {
413 task_depths.insert(task1.task_id.clone(), 0);
414 hierarchy.insert(task1.task_id.clone(), Vec::new());
415
416 for (j, task2) in tasks.iter().enumerate() {
417 if i != j {
418 let weight = 1.0 / (i.abs_diff(j) + 1) as f64;
420 hierarchy_weights
421 .insert((task1.task_id.clone(), task2.task_id.clone()), weight);
422 }
423 }
424 }
425
426 self.task_hierarchy = Some(TaskHierarchy {
427 hierarchy,
428 task_depths,
429 hierarchy_weights,
430 });
431
432 Ok(())
433 }
434
435 fn initialize_feature_selector(&mut self, tasks: &[TaskData]) -> SklResult<()> {
437 if tasks.is_empty() {
438 return Ok(());
439 }
440
441 let n_features = tasks[0].features.shape()[1];
442 let mut task_feature_masks = HashMap::new();
443 let mut task_feature_importances = HashMap::new();
444
445 for task in tasks {
447 let mut feature_mask = vec![true; n_features];
448 let mut feature_importances = vec![0.0; n_features];
449
450 for j in 0..n_features {
452 let column: Vec<f64> = (0..task.features.shape()[0])
453 .map(|i| task.features[[i, j]])
454 .collect();
455 let mean = column.iter().sum::<f64>() / column.len() as f64;
456 let variance =
457 column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / column.len() as f64;
458
459 feature_importances[j] = variance;
460 feature_mask[j] = variance > 1e-8; }
462
463 task_feature_masks.insert(task.task_id.clone(), feature_mask);
464 task_feature_importances.insert(task.task_id.clone(), feature_importances);
465 }
466
467 let mut global_feature_importances = vec![0.0; n_features];
469 for j in 0..n_features {
470 let sum: f64 = task_feature_importances
471 .values()
472 .map(|importances| importances[j])
473 .sum();
474 global_feature_importances[j] = sum / tasks.len() as f64;
475 }
476
477 let shared_feature_mask: Vec<bool> = (0..n_features)
479 .map(|j| {
480 let important_count = task_feature_importances
481 .values()
482 .filter(|importances| importances[j] > global_feature_importances[j] * 0.5)
483 .count();
484 important_count >= tasks.len() / 2
485 })
486 .collect();
487
488 self.feature_selector = Some(MultiTaskFeatureSelector {
489 task_feature_masks,
490 shared_feature_mask,
491 task_feature_importances,
492 global_feature_importances,
493 });
494
495 Ok(())
496 }
497
498 fn compute_task_similarities(&mut self, tasks: &[TaskData]) -> SklResult<()> {
500 let mut similarities = HashMap::new();
501
502 for (i, task1) in tasks.iter().enumerate() {
503 for task2 in tasks.iter().skip(i + 1) {
504 let similarity = self.calculate_task_similarity(task1, task2)?;
505 similarities.insert((task1.task_id.clone(), task2.task_id.clone()), similarity);
506 similarities.insert((task2.task_id.clone(), task1.task_id.clone()), similarity);
507 }
508 }
509
510 self.task_similarities = Some(similarities);
511 Ok(())
512 }
513
514 fn calculate_task_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
516 match self.config.similarity_metric {
517 TaskSimilarityMetric::CorrelationBased => self.correlation_similarity(task1, task2),
518 TaskSimilarityMetric::DistributionSimilarity => {
519 self.distribution_similarity(task1, task2)
520 }
521 _ => {
522 self.correlation_similarity(task1, task2)
524 }
525 }
526 }
527
528 fn correlation_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
530 if task1.labels.len() != task2.labels.len() {
532 return Ok(0.0); }
534
535 let n = task1.labels.len();
536 if n < 2 {
537 return Ok(0.0);
538 }
539
540 let mean1 = task1.labels.iter().sum::<f64>() / n as f64;
541 let mean2 = task2.labels.iter().sum::<f64>() / n as f64;
542
543 let mut numerator = 0.0;
544 let mut denom1 = 0.0;
545 let mut denom2 = 0.0;
546
547 for i in 0..n {
548 let diff1 = task1.labels[i] - mean1;
549 let diff2 = task2.labels[i] - mean2;
550 numerator += diff1 * diff2;
551 denom1 += diff1 * diff1;
552 denom2 += diff2 * diff2;
553 }
554
555 if denom1 * denom2 > 0.0 {
556 Ok(numerator / (denom1 * denom2).sqrt())
557 } else {
558 Ok(0.0)
559 }
560 }
561
562 fn distribution_similarity(&self, task1: &TaskData, task2: &TaskData) -> SklResult<f64> {
564 let n_features1 = task1.features.shape()[1];
566 let n_features2 = task2.features.shape()[1];
567
568 if n_features1 != n_features2 {
569 return Ok(0.0);
570 }
571
572 let mut similarity_sum = 0.0;
573
574 for j in 0..n_features1 {
575 let col1: Vec<f64> = (0..task1.features.shape()[0])
576 .map(|i| task1.features[[i, j]])
577 .collect();
578 let col2: Vec<f64> = (0..task2.features.shape()[0])
579 .map(|i| task2.features[[i, j]])
580 .collect();
581
582 let mean1 = col1.iter().sum::<f64>() / col1.len() as f64;
583 let mean2 = col2.iter().sum::<f64>() / col2.len() as f64;
584
585 let var1 = col1.iter().map(|&x| (x - mean1).powi(2)).sum::<f64>() / col1.len() as f64;
586 let var2 = col2.iter().map(|&x| (x - mean2).powi(2)).sum::<f64>() / col2.len() as f64;
587
588 let mean_sim = 1.0 - (mean1 - mean2).abs() / (mean1.abs() + mean2.abs() + 1e-8);
590 let var_sim = 1.0 - (var1 - var2).abs() / (var1 + var2 + 1e-8);
591
592 similarity_sum += (mean_sim + var_sim) / 2.0;
593 }
594
595 Ok(similarity_sum / n_features1 as f64)
596 }
597
598 fn train_independent_tasks(
600 &mut self,
601 tasks: &[TaskData],
602 ) -> SklResult<MultiTaskTrainingResults> {
603 let mut task_metrics = HashMap::new();
604
605 if self.task_models.is_none() {
607 self.task_models = Some(HashMap::new());
608 }
609
610 for task in tasks {
611 let mut models = Vec::new();
612
613 for _ in 0..self.config.n_estimators_per_task {
614 let gb_config = GradientBoostingConfig {
615 n_estimators: 50,
616 learning_rate: 0.1,
617 max_depth: 6,
618 ..Default::default()
619 };
620
621 let y_array = Array1::from_vec(task.labels.clone());
622 let model = GradientBoostingRegressor::builder()
623 .n_estimators(50)
624 .learning_rate(0.1)
625 .max_depth(6)
626 .build()
627 .fit(&task.features, &y_array)?;
628
629 models.push(model);
630 }
631
632 let predictions = self.predict_task_ensemble(&models, &task.features)?;
634 let mse = self.calculate_mse(&predictions, &task.labels);
635
636 task_metrics.insert(
637 task.task_id.clone(),
638 TaskMetrics {
639 training_score: mse,
640 validation_score: mse, n_samples: task.features.shape()[0],
642 training_time: 0.0, complexity: models.len() as f64,
644 },
645 );
646
647 self.task_models
648 .as_mut()
649 .expect("operation should succeed")
650 .insert(task.task_id.clone(), models);
651 }
652
653 Ok(MultiTaskTrainingResults {
654 task_metrics,
655 transfer_effects: HashMap::new(),
656 final_similarities: self.task_similarities.clone().unwrap_or_default(),
657 convergence_info: ConvergenceInfo {
658 n_iterations: 1,
659 final_loss: 0.0,
660 tolerance_achieved: 0.0,
661 converged: true,
662 },
663 })
664 }
665
666 fn train_shared_representation(
668 &mut self,
669 tasks: &[TaskData],
670 ) -> SklResult<MultiTaskTrainingResults> {
671 let combined_data = self.combine_task_data(tasks)?;
673
674 if self.shared_models.is_none() {
676 self.shared_models = Some(Vec::new());
677 }
678
679 for _ in 0..self.config.n_estimators_per_task {
680 let gb_config = GradientBoostingConfig {
681 n_estimators: 30,
682 learning_rate: 0.1,
683 max_depth: 4,
684 ..Default::default()
685 };
686
687 let y_array = Array1::from_vec(combined_data.labels.clone());
688 let shared_model = GradientBoostingRegressor::builder()
689 .n_estimators(30)
690 .learning_rate(0.1)
691 .max_depth(4)
692 .build()
693 .fit(&combined_data.features, &y_array)?;
694
695 self.shared_models
696 .as_mut()
697 .expect("operation should succeed")
698 .push(shared_model);
699 }
700
701 self.train_independent_tasks(tasks)
703 }
704
705 fn train_parameter_sharing(
707 &mut self,
708 tasks: &[TaskData],
709 ) -> SklResult<MultiTaskTrainingResults> {
710 let task_groups = self.group_similar_tasks(tasks)?;
712
713 let mut task_metrics = HashMap::new();
714
715 for group in task_groups {
716 let group_data = self.combine_group_data(tasks, &group)?;
718
719 let mut group_models = Vec::new();
720 for _ in 0..self.config.n_estimators_per_task {
721 let gb_config = GradientBoostingConfig {
722 n_estimators: 40,
723 learning_rate: 0.1,
724 max_depth: 5,
725 ..Default::default()
726 };
727
728 let y_array = Array1::from_vec(group_data.labels.clone());
729 let model = GradientBoostingRegressor::builder()
730 .n_estimators(40)
731 .learning_rate(0.1)
732 .max_depth(5)
733 .build()
734 .fit(&group_data.features, &y_array)?;
735
736 group_models.push(model);
737 }
738
739 for task_id in &group {
741 let task = tasks
742 .iter()
743 .find(|t| &t.task_id == task_id)
744 .expect("operation should succeed");
745
746 let predictions = self.predict_task_ensemble(&group_models, &task.features)?;
748 let mse = self.calculate_mse(&predictions, &task.labels);
749
750 task_metrics.insert(
751 task_id.clone(),
752 TaskMetrics {
753 training_score: mse,
754 validation_score: mse,
755 n_samples: task.features.shape()[0],
756 training_time: 0.0,
757 complexity: group_models.len() as f64,
758 },
759 );
760
761 let mut task_models = Vec::new();
763 for _ in 0..self.config.n_estimators_per_task {
764 let y_array = Array1::from_vec(group_data.labels.clone());
765 let model = GradientBoostingRegressor::builder()
766 .n_estimators(40)
767 .learning_rate(0.1)
768 .max_depth(5)
769 .build()
770 .fit(&group_data.features, &y_array)?;
771 task_models.push(model);
772 }
773
774 self.task_models
775 .as_mut()
776 .expect("operation should succeed")
777 .insert(task_id.clone(), task_models);
778 }
779 }
780
781 Ok(MultiTaskTrainingResults {
782 task_metrics,
783 transfer_effects: HashMap::new(),
784 final_similarities: self.task_similarities.clone().unwrap_or_default(),
785 convergence_info: ConvergenceInfo {
786 n_iterations: 1,
787 final_loss: 0.0,
788 tolerance_achieved: 0.0,
789 converged: true,
790 },
791 })
792 }
793
794 fn train_hierarchical_sharing(
796 &mut self,
797 tasks: &[TaskData],
798 ) -> SklResult<MultiTaskTrainingResults> {
799 self.train_parameter_sharing(tasks)
802 }
803
804 fn train_adaptive_sharing(
806 &mut self,
807 tasks: &[TaskData],
808 ) -> SklResult<MultiTaskTrainingResults> {
809 let mut task_metrics = HashMap::new();
810
811 for task in tasks {
812 let mut models = Vec::new();
813
814 let similar_tasks = self.find_similar_tasks(&task.task_id, tasks);
816
817 if similar_tasks.len() > 1 {
818 let combined_data = self.combine_similar_task_data(tasks, &similar_tasks)?;
820
821 for _ in 0..self.config.n_estimators_per_task {
822 let gb_config = GradientBoostingConfig {
823 n_estimators: 50,
824 learning_rate: 0.1,
825 max_depth: 6,
826 ..Default::default()
827 };
828
829 let y_array = Array1::from_vec(combined_data.labels.clone());
830 let model = GradientBoostingRegressor::builder()
831 .n_estimators(50)
832 .learning_rate(0.1)
833 .max_depth(6)
834 .build()
835 .fit(&combined_data.features, &y_array)?;
836
837 models.push(model);
838 }
839 } else {
840 for _ in 0..self.config.n_estimators_per_task {
842 let gb_config = GradientBoostingConfig {
843 n_estimators: 50,
844 learning_rate: 0.1,
845 max_depth: 6,
846 ..Default::default()
847 };
848
849 let y_array = Array1::from_vec(task.labels.clone());
850 let model = GradientBoostingRegressor::builder()
851 .n_estimators(50)
852 .learning_rate(0.1)
853 .max_depth(6)
854 .build()
855 .fit(&task.features, &y_array)?;
856
857 models.push(model);
858 }
859 }
860
861 let predictions = self.predict_task_ensemble(&models, &task.features)?;
863 let mse = self.calculate_mse(&predictions, &task.labels);
864
865 task_metrics.insert(
866 task.task_id.clone(),
867 TaskMetrics {
868 training_score: mse,
869 validation_score: mse,
870 n_samples: task.features.shape()[0],
871 training_time: 0.0,
872 complexity: models.len() as f64,
873 },
874 );
875
876 self.task_models
877 .as_mut()
878 .expect("operation should succeed")
879 .insert(task.task_id.clone(), models);
880 }
881
882 Ok(MultiTaskTrainingResults {
883 task_metrics,
884 transfer_effects: HashMap::new(),
885 final_similarities: self.task_similarities.clone().unwrap_or_default(),
886 convergence_info: ConvergenceInfo {
887 n_iterations: 1,
888 final_loss: 0.0,
889 tolerance_achieved: 0.0,
890 converged: true,
891 },
892 })
893 }
894
895 fn combine_task_data(&self, tasks: &[TaskData]) -> SklResult<TaskData> {
897 if tasks.is_empty() {
898 return Err(SklearsError::InvalidInput(
899 "No tasks to combine".to_string(),
900 ));
901 }
902
903 let total_samples: usize = tasks.iter().map(|t| t.features.shape()[0]).sum();
904 let n_features = tasks[0].features.shape()[1];
905
906 let mut combined_features = Vec::with_capacity(total_samples * n_features);
907 let mut combined_labels = Vec::with_capacity(total_samples);
908
909 for task in tasks {
910 for i in 0..task.features.shape()[0] {
911 for j in 0..n_features {
912 combined_features.push(task.features[[i, j]]);
913 }
914 combined_labels.push(task.labels[i]);
915 }
916 }
917
918 let features = Array2::from_shape_vec((total_samples, n_features), combined_features)?;
919
920 Ok(TaskData {
921 task_id: "combined".to_string(),
922 features,
923 labels: combined_labels,
924 sample_weights: None,
925 metadata: HashMap::new(),
926 })
927 }
928
929 fn group_similar_tasks(&self, tasks: &[TaskData]) -> SklResult<Vec<Vec<String>>> {
931 let mut groups = Vec::new();
932 let mut assigned = vec![false; tasks.len()];
933
934 for (i, task) in tasks.iter().enumerate() {
935 if assigned[i] {
936 continue;
937 }
938
939 let mut group = vec![task.task_id.clone()];
940 assigned[i] = true;
941
942 for (j, other_task) in tasks.iter().enumerate() {
944 if i != j && !assigned[j] {
945 let similarity = self
946 .task_similarities
947 .as_ref()
948 .and_then(|similarities| {
949 similarities.get(&(task.task_id.clone(), other_task.task_id.clone()))
950 })
951 .copied()
952 .unwrap_or(0.0);
953
954 if similarity >= self.config.min_similarity_threshold {
955 group.push(other_task.task_id.clone());
956 assigned[j] = true;
957 }
958 }
959 }
960
961 groups.push(group);
962 }
963
964 Ok(groups)
965 }
966
967 fn combine_group_data(&self, tasks: &[TaskData], group: &[String]) -> SklResult<TaskData> {
969 let group_tasks: Vec<TaskData> = tasks
970 .iter()
971 .filter(|t| group.contains(&t.task_id))
972 .cloned()
973 .collect();
974
975 self.combine_task_data(&group_tasks)
976 }
977
978 fn find_similar_tasks(&self, task_id: &str, tasks: &[TaskData]) -> Vec<String> {
980 let mut similar_tasks = vec![task_id.to_string()];
981
982 for task in tasks {
983 if task.task_id != task_id {
984 let similarity = self
985 .task_similarities
986 .as_ref()
987 .and_then(|similarities| {
988 similarities.get(&(task_id.to_string(), task.task_id.clone()))
989 })
990 .copied()
991 .unwrap_or(0.0);
992
993 if similarity >= self.config.min_similarity_threshold {
994 similar_tasks.push(task.task_id.clone());
995 }
996 }
997 }
998
999 similar_tasks
1000 }
1001
1002 fn combine_similar_task_data(
1004 &self,
1005 tasks: &[TaskData],
1006 similar_task_ids: &[String],
1007 ) -> SklResult<TaskData> {
1008 let similar_tasks: Vec<TaskData> = tasks
1009 .iter()
1010 .filter(|t| similar_task_ids.contains(&t.task_id))
1011 .cloned()
1012 .collect();
1013
1014 self.combine_task_data(&similar_tasks)
1015 }
1016
1017 fn predict_task_ensemble(
1019 &self,
1020 models: &[TrainedGradientBoostingRegressor],
1021 X: &Array2<f64>,
1022 ) -> SklResult<Vec<f64>> {
1023 if models.is_empty() {
1024 return Err(SklearsError::InvalidInput(
1025 "No models in ensemble".to_string(),
1026 ));
1027 }
1028
1029 let mut predictions = vec![0.0; X.shape()[0]];
1030
1031 for model in models {
1032 let pred = model.predict(X)?;
1033 for (i, &p) in pred.iter().enumerate() {
1034 predictions[i] += p;
1035 }
1036 }
1037
1038 for p in &mut predictions {
1040 *p /= models.len() as f64;
1041 }
1042
1043 Ok(predictions)
1044 }
1045
1046 fn calculate_mse(&self, predictions: &[f64], targets: &[f64]) -> f64 {
1048 if predictions.len() != targets.len() {
1049 return f64::INFINITY;
1050 }
1051
1052 let sum_squared_error: f64 = predictions
1053 .iter()
1054 .zip(targets.iter())
1055 .map(|(&p, &t)| (p - t).powi(2))
1056 .sum();
1057
1058 sum_squared_error / predictions.len() as f64
1059 }
1060}
1061
1062impl MultiTaskEnsembleRegressor<Trained> {
1063 pub fn predict_task(&self, task_id: &str, X: &Array2<f64>) -> SklResult<Vec<f64>> {
1065 if let Some(models) = self.task_models.as_ref().and_then(|m| m.get(task_id)) {
1067 let mut predictions = self.predict_task_ensemble(models, X)?;
1068
1069 if let Some(shared_models) = self.shared_models.as_ref() {
1071 if !shared_models.is_empty() {
1072 let shared_predictions = self.predict_task_ensemble(shared_models, X)?;
1073 for (i, &shared_pred) in shared_predictions.iter().enumerate() {
1074 predictions[i] = 0.7 * predictions[i] + 0.3 * shared_pred;
1075 }
1076 }
1077 }
1078
1079 Ok(predictions)
1080 } else {
1081 Err(SklearsError::InvalidInput(format!(
1082 "Task '{}' not found",
1083 task_id
1084 )))
1085 }
1086 }
1087
1088 fn predict_task_ensemble(
1090 &self,
1091 models: &[TrainedGradientBoostingRegressor],
1092 X: &Array2<f64>,
1093 ) -> SklResult<Vec<f64>> {
1094 if models.is_empty() {
1095 return Err(SklearsError::InvalidInput(
1096 "No models in ensemble".to_string(),
1097 ));
1098 }
1099
1100 let mut predictions = vec![0.0; X.shape()[0]];
1101
1102 for model in models {
1103 let pred = model.predict(X)?;
1104 for (i, &p) in pred.iter().enumerate() {
1105 predictions[i] += p;
1106 }
1107 }
1108
1109 for p in &mut predictions {
1111 *p /= models.len() as f64;
1112 }
1113
1114 Ok(predictions)
1115 }
1116
1117 pub fn get_task_similarities(&self) -> &HashMap<(String, String), f64> {
1119 self.task_similarities.as_ref().expect("Model is trained")
1120 }
1121
1122 pub fn get_task_weights(&self) -> &HashMap<String, f64> {
1124 self.task_weights.as_ref().expect("Model is trained")
1125 }
1126}
1127
/// Fluent builder for [`MultiTaskEnsembleRegressor`].
pub struct MultiTaskEnsembleRegressorBuilder {
    // Configuration being assembled; handed to the regressor by `build()`.
    config: MultiTaskEnsembleConfig,
}
1131
1132impl Default for MultiTaskEnsembleRegressorBuilder {
1133 fn default() -> Self {
1134 Self::new()
1135 }
1136}
1137
1138impl MultiTaskEnsembleRegressorBuilder {
1139 pub fn new() -> Self {
1140 Self {
1141 config: MultiTaskEnsembleConfig::default(),
1142 }
1143 }
1144
1145 pub fn config(mut self, config: MultiTaskEnsembleConfig) -> Self {
1146 self.config = config;
1147 self
1148 }
1149
1150 pub fn n_estimators_per_task(mut self, n_estimators: usize) -> Self {
1151 self.config.n_estimators_per_task = n_estimators;
1152 self
1153 }
1154
1155 pub fn sharing_strategy(mut self, strategy: TaskSharingStrategy) -> Self {
1156 self.config.sharing_strategy = strategy;
1157 self
1158 }
1159
1160 pub fn min_similarity_threshold(mut self, threshold: f64) -> Self {
1161 self.config.min_similarity_threshold = threshold;
1162 self
1163 }
1164
1165 pub fn build(self) -> MultiTaskEnsembleRegressor {
1166 MultiTaskEnsembleRegressor::new(self.config)
1167 }
1168}
1169
// Implements the framework's `Estimator` trait for the untrained regressor,
// exposing its configuration, error, and float types.
impl Estimator for MultiTaskEnsembleRegressor {
    type Config = MultiTaskEnsembleConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Returns a reference to the ensemble's configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
1179
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // Builder should apply each setter and leave the rest at defaults.
    #[test]
    fn test_multi_task_config() {
        let config = MultiTaskEnsembleConfig::builder()
            .n_estimators_per_task(5)
            .sharing_strategy(TaskSharingStrategy::SharedRepresentation)
            .min_similarity_threshold(0.5)
            .build();

        assert_eq!(config.n_estimators_per_task, 5);
        assert_eq!(
            config.sharing_strategy,
            TaskSharingStrategy::SharedRepresentation
        );
        assert_eq!(config.min_similarity_threshold, 0.5);
    }

    // TaskData fields should round-trip the data passed in.
    #[test]
    fn test_task_data_creation() {
        let features = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape and data length should match");
        let labels = vec![1.0, 2.0, 3.0];

        let task = TaskData {
            task_id: "test_task".to_string(),
            features,
            labels,
            sample_weights: None,
            metadata: HashMap::new(),
        };

        assert_eq!(task.task_id, "test_task");
        assert_eq!(task.features.shape(), &[3, 2]);
        assert_eq!(task.labels.len(), 3);
    }

    // A freshly built (untrained) ensemble carries its config and no models.
    #[test]
    fn test_multi_task_ensemble_basic() {
        let config = MultiTaskEnsembleConfig::builder()
            .n_estimators_per_task(2)
            .sharing_strategy(TaskSharingStrategy::Independent)
            .build();

        let ensemble = MultiTaskEnsembleRegressor::new(config);

        assert_eq!(ensemble.config.n_estimators_per_task, 2);
        assert_eq!(
            ensemble.config.sharing_strategy,
            TaskSharingStrategy::Independent
        );
        // Models are only populated by fitting.
        assert!(ensemble.task_models.is_none());
    }

    // Identical label vectors must have Pearson correlation 1.0.
    #[test]
    fn test_task_similarity_calculation() {
        let config = MultiTaskEnsembleConfig::default();
        let ensemble = MultiTaskEnsembleRegressor::new(config);

        let task1 = TaskData {
            task_id: "task1".to_string(),
            features: Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
                .expect("shape and data length should match"),
            labels: vec![1.0, 2.0],
            sample_weights: None,
            metadata: HashMap::new(),
        };

        let task2 = TaskData {
            task_id: "task2".to_string(),
            features: Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
                .expect("shape and data length should match"),
            labels: vec![1.0, 2.0],
            sample_weights: None,
            metadata: HashMap::new(),
        };

        let similarity = ensemble
            .correlation_similarity(&task1, &task2)
            .expect("operation should succeed");
        assert!((similarity - 1.0).abs() < 1e-10);
    }
}