1use serde::{Deserialize, Serialize};
41use std::collections::HashMap;
42use trustformers_core::{
43 errors::invalid_input,
44 layers::Linear,
45 tensor::Tensor,
46 traits::{Layer, Model},
47 Result,
48};
49
/// Top-level configuration for multi-task learning (MTL) training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MTLConfig {
    /// Parameter-sharing architecture used across tasks.
    /// NOTE(review): only recorded/reported by the trainer in this module;
    /// nothing here branches on the variant yet — confirm where it is used.
    pub architecture: MTLArchitecture,
    /// Strategy for combining per-task losses into one objective.
    pub loss_balancing: LossBalancingStrategy,
    /// The set of tasks trained jointly.
    pub tasks: Vec<TaskConfig>,
    /// Whether to learn a dedicated embedding per task.
    pub use_task_embeddings: bool,
    /// Dimensionality of task embeddings (relevant when enabled).
    pub task_embedding_dim: usize,
    /// Whether auxiliary (self-supervised) tasks run alongside the main tasks.
    pub use_auxiliary_tasks: bool,
    /// Auxiliary task definitions (consulted only when enabled).
    pub auxiliary_tasks: Vec<AuxiliaryTaskConfig>,
    /// Optional setup for grouping related tasks into clusters.
    pub task_clustering: Option<TaskClusteringConfig>,
    /// Evaluation cadence in steps.
    /// NOTE(review): not consulted by the code visible here — confirm usage.
    pub evaluation_frequency: usize,
    /// Whether the task scheduler is active.
    /// NOTE(review): `get_active_tasks` applies `task_scheduling` regardless
    /// of this flag — confirm intended semantics.
    pub use_task_scheduling: bool,
    /// Strategy for picking which task(s) to train each step.
    pub task_scheduling: TaskSchedulingStrategy,
}
76
77impl Default for MTLConfig {
78 fn default() -> Self {
79 Self {
80 architecture: MTLArchitecture::HardParameterSharing {
81 shared_layers: 8,
82 task_specific_layers: 2,
83 },
84 loss_balancing: LossBalancingStrategy::EqualWeighting,
85 tasks: Vec::new(),
86 use_task_embeddings: false,
87 task_embedding_dim: 64,
88 use_auxiliary_tasks: false,
89 auxiliary_tasks: Vec::new(),
90 task_clustering: None,
91 evaluation_frequency: 1000,
92 use_task_scheduling: false,
93 task_scheduling: TaskSchedulingStrategy::RoundRobin,
94 }
95 }
96}
97
/// How parameters are shared between tasks.
///
/// NOTE(review): currently informational — the trainer in this module does
/// not branch on the chosen variant (it only echoes it in `MTLStats`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MTLArchitecture {
    /// One shared trunk with small task-specific towers on top.
    HardParameterSharing {
        shared_layers: usize,
        task_specific_layers: usize,
    },
    /// Per-task models tied together by a regularizer on their weights.
    SoftParameterSharing {
        regularization_weight: f32,
        regularization_type: RegularizationType,
    },
    /// MMoE: shared experts combined by per-task gating networks.
    MultiGateMixtureOfExperts {
        num_experts: usize,
        expert_dim: usize,
        num_gates: usize,
    },
    /// Cross-stitch units mixing activations between task columns.
    CrossStitchNetworks {
        num_tasks: usize,
        cross_stitch_layers: Vec<usize>,
    },
    /// Learned routers directing examples through task-specific paths.
    TaskRoutingNetworks {
        num_routers: usize,
        routing_dim: usize,
    },
    /// Progressive-networks-style columns with optional lateral links.
    ProgressiveNetworks {
        lateral_connections: bool,
        adapter_layers: bool,
    },
    /// Attention over shared features deciding what each task consumes.
    AttentionBasedSharing {
        attention_dim: usize,
        num_attention_heads: usize,
    },
}
138
/// Regularizer tying task-specific parameters together under soft sharing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RegularizationType {
    /// Squared-L2 penalty between task parameters.
    L2Regularization,
    /// Trace-norm penalty.
    TraceNorm,
    /// Group-lasso penalty.
    GroupLasso,
    /// Combined L1 + L2 penalty with separate weights.
    ElasticNet { l1_weight: f32, l2_weight: f32 },
}
151
/// Strategy for combining per-task losses into a single training objective.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossBalancingStrategy {
    /// Every task contributes equally.
    EqualWeighting,
    /// Fixed user-supplied weights, indexed by task position in the config.
    ManualWeighting { weights: Vec<f32> },
    /// Uncertainty-based weighting (currently a pass-through in this module).
    UncertaintyWeighting,
    /// Re-weight losses by recent loss ratios (DWA-style).
    DynamicWeightAverage,
    /// GradNorm balancing (currently a pass-through in this module).
    GradNorm { alpha: f32 },
    /// Balance via task sampling rather than loss scaling.
    TaskBalancedSampling,
    /// Focal re-weighting with focusing parameter `gamma`.
    FocalLoss { gamma: f32 },
    /// Learn the weighting itself with a meta-objective.
    MetaLearning { meta_lr: f32 },
}
172
/// Per-task training configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskConfig {
    /// Unique task identifier; used as the key for heads, weights and data.
    pub name: String,
    /// What the task predicts (drives head construction and loss choice).
    pub task_type: TaskType,
    /// Loss weight applied when summing into the total objective.
    pub weight: f32,
    /// Scheduling priority (consulted by weighted sampling).
    pub priority: TaskPriority,
    /// Marks the primary task of the setup.
    pub is_main_task: bool,
    /// Optional per-task learning-rate override (`None` = use global).
    pub learning_rate: Option<f32>,
    /// Optional per-task batch-size override (`None` = use global).
    pub batch_size: Option<usize>,
}
191
192impl TaskConfig {
193 pub fn new(name: &str, task_type: TaskType) -> Self {
194 Self {
195 name: name.to_string(),
196 task_type,
197 weight: 1.0,
198 priority: TaskPriority::Normal,
199 is_main_task: false,
200 learning_rate: None,
201 batch_size: None,
202 }
203 }
204
205 pub fn with_weight(mut self, weight: f32) -> Self {
206 self.weight = weight;
207 self
208 }
209
210 pub fn with_priority(mut self, priority: TaskPriority) -> Self {
211 self.priority = priority;
212 self
213 }
214
215 pub fn as_main_task(mut self) -> Self {
216 self.is_main_task = true;
217 self
218 }
219}
220
/// The kind of prediction a task performs; determines its head and loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskType {
    /// Multi-class classification over `num_classes` labels.
    Classification {
        num_classes: usize,
        use_class_weights: bool,
    },
    /// Real-valued prediction of `output_dim` values.
    Regression {
        output_dim: usize,
        loss_type: RegressionLossType,
    },
    /// Per-token labeling, optionally with CRF decoding.
    SequenceLabeling { num_labels: usize, use_crf: bool },
    /// Text generation over a vocabulary, up to `max_length` tokens.
    Generation {
        vocab_size: usize,
        max_length: usize,
    },
    /// Learning-to-rank with the given formulation.
    Ranking { ranking_type: RankingType },
    /// Self-supervised auxiliary objective.
    Auxiliary { auxiliary_type: AuxiliaryType },
}
246
/// Loss function used for regression tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RegressionLossType {
    /// Mean squared error.
    MSE,
    /// Mean absolute error.
    MAE,
    /// Huber loss with transition point `delta`.
    Huber { delta: f32 },
    /// Log-cosh loss (currently falls back to MSE in `compute_task_loss`).
    LogCosh,
}
255
/// Formulation used for learning-to-rank tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RankingType {
    /// Score pairs of items relative to each other.
    Pairwise,
    /// Optimize an objective over whole ranked lists.
    Listwise,
    /// Score each item independently.
    Pointwise,
}
263
/// Self-supervised objectives that can run alongside the main tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuxiliaryType {
    /// Causal (next-token) language modeling.
    LanguageModeling,
    /// Masked-token prediction.
    MaskedLanguageModeling,
    /// Predict whether one sentence follows another.
    NextSentencePrediction,
    /// Predict the correct ordering of sentences.
    SentenceOrderPrediction,
    /// Predict the correct ordering of words.
    WordOrderPrediction,
    /// User-defined auxiliary objective, identified by name.
    Custom { name: String },
}
274
/// Scheduling priority; under `TaskSchedulingStrategy::WeightedSampling`
/// tasks are sampled at 0.5x / 1x / 2x / 3x frequency respectively.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskPriority {
    Low,
    Normal,
    High,
    Critical,
}
283
/// Configuration of a single auxiliary task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuxiliaryTaskConfig {
    /// Key used to look up this task's data in the per-step batch map.
    pub name: String,
    /// Which self-supervised objective to run.
    pub auxiliary_type: AuxiliaryType,
    /// Scale applied to the auxiliary loss before adding it to the total.
    pub weight: f32,
    /// When the auxiliary task should run.
    pub frequency: AuxiliaryTaskFrequency,
}
292
/// Schedule controlling when an auxiliary task is trained.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuxiliaryTaskFrequency {
    /// Run whenever `step % n == 0`.
    EveryNSteps(usize),
    /// Run each step with the given probability.
    WithProbability(f32),
    /// Run on every step.
    Continuous,
    /// Run only while the current epoch lies within `[start, end]`.
    EpochRange { start: usize, end: usize },
}
305
/// Settings for grouping related tasks into clusters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskClusteringConfig {
    /// Signal used to measure task relatedness.
    pub clustering_method: ClusteringMethod,
    /// Target number of clusters.
    pub num_clusters: usize,
    /// Re-cluster every this many steps.
    pub update_frequency: usize,
}
313
/// Signal used to decide which tasks belong together.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusteringMethod {
    /// Group tasks whose gradients are similar.
    GradientSimilarity,
    /// Group tasks whose performance curves correlate.
    PerformanceCorrelation,
    /// Group tasks whose data is similar.
    DataSimilarity,
    /// Hand-specified clusters of task names.
    Manual { clusters: Vec<Vec<String>> },
}
326
/// Policy for choosing which task(s) to train on a given step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskSchedulingStrategy {
    /// Cycle through the available tasks, one per step.
    RoundRobin,
    /// Sample tasks proportionally to their configured priority.
    WeightedSampling,
    /// Favor tasks based on performance (currently trains all tasks).
    PerformanceBased,
    /// Follow a fixed easy-to-hard ordering (currently trains all tasks).
    CurriculumBased { difficulty_order: Vec<String> },
    /// Pick one available task uniformly at random.
    Random,
}
341
/// Orchestrates joint training of one shared backbone plus per-task heads.
pub struct MultiTaskLearningTrainer<M: Model> {
    /// Shared backbone producing features consumed by every task head.
    pub base_model: M,
    /// One output head per task, keyed by task name.
    pub task_heads: HashMap<String, TaskHead>,
    /// Full MTL configuration.
    pub config: MTLConfig,
    /// Current per-task loss weights (may drift from the config values
    /// under dynamic weight averaging).
    pub task_weights: HashMap<String, f32>,
    /// Accuracy history per task, appended on every training step.
    pub task_performance: HashMap<String, Vec<f32>>,
    /// Number of completed training steps.
    pub step_counter: usize,
    /// Scheduler bookkeeping.
    /// NOTE(review): initialized but never read/updated here — confirm need.
    pub scheduler_state: TaskSchedulerState,
    /// Per-task gradient statistics.
    /// NOTE(review): never populated in this module — confirm intended use.
    pub gradient_stats: HashMap<String, GradientStats>,
}
361
362impl<M: Model<Input = Tensor, Output = Tensor>> MultiTaskLearningTrainer<M> {
363 pub fn new(base_model: M, config: MTLConfig) -> Result<Self> {
365 let mut task_heads = HashMap::new();
366 let mut task_weights = HashMap::new();
367
368 for task_config in &config.tasks {
370 let task_head = TaskHead::new(&task_config.task_type)?;
371 task_heads.insert(task_config.name.clone(), task_head);
372 task_weights.insert(task_config.name.clone(), task_config.weight);
373 }
374
375 let scheduler_state = TaskSchedulerState::new(&config.task_scheduling);
376
377 Ok(Self {
378 base_model,
379 task_heads,
380 config,
381 task_weights,
382 task_performance: HashMap::new(),
383 step_counter: 0,
384 scheduler_state,
385 gradient_stats: HashMap::new(),
386 })
387 }
388
389 pub fn train_multi_task_step(
391 &mut self,
392 task_data: &HashMap<String, TaskBatch>,
393 ) -> Result<MultiTaskOutput> {
394 let mut task_losses = HashMap::new();
395 let mut task_accuracies = HashMap::new();
396 let mut total_loss = Tensor::zeros(&[1])?;
397
398 let active_tasks = self.get_active_tasks(task_data)?;
400
401 for task_name in &active_tasks {
402 if let Some(batch) = task_data.get(task_name) {
403 let shared_features = self.base_model.forward(batch.inputs.clone())?;
405
406 let task_head = self
408 .task_heads
409 .get(task_name)
410 .ok_or_else(|| anyhow::anyhow!("Task head not found: {}", task_name))?;
411
412 let task_outputs = task_head.forward(&shared_features)?;
413 let task_loss = self.compute_task_loss(task_name, &task_outputs, &batch.targets)?;
414 let task_accuracy =
415 self.compute_task_accuracy(task_name, &task_outputs, &batch.targets)?;
416
417 task_losses.insert(task_name.clone(), task_loss.clone());
418 task_accuracies.insert(task_name.clone(), task_accuracy);
419
420 self.task_performance.entry(task_name.clone()).or_default().push(task_accuracy);
422 }
423 }
424
425 let balanced_losses = self.balance_losses(&task_losses)?;
427
428 for (task_name, loss) in &balanced_losses {
430 let weight = self.task_weights.get(task_name).copied().unwrap_or(1.0);
431 total_loss = total_loss.add(&loss.scalar_mul(weight)?)?;
432 }
433
434 self.update_task_weights(&task_losses)?;
436
437 if self.config.use_auxiliary_tasks {
439 let aux_loss = self.compute_auxiliary_losses(task_data)?;
440 total_loss = total_loss.add(&aux_loss)?;
441 }
442
443 self.step_counter += 1;
444
445 Ok(MultiTaskOutput {
446 total_loss,
447 task_losses: task_losses
448 .into_iter()
449 .map(|(k, v)| (k, v.to_scalar().unwrap_or(0.0)))
450 .collect(),
451 task_accuracies,
452 active_tasks,
453 task_weights: self.task_weights.clone(),
454 })
455 }
456
457 fn get_active_tasks(&mut self, task_data: &HashMap<String, TaskBatch>) -> Result<Vec<String>> {
459 match &self.config.task_scheduling {
460 TaskSchedulingStrategy::RoundRobin => {
461 let task_names: Vec<String> = task_data.keys().cloned().collect();
462 if task_names.is_empty() {
463 return Ok(Vec::new());
464 }
465 let current_task = &task_names[self.step_counter % task_names.len()];
466 Ok(vec![current_task.clone()])
467 },
468 TaskSchedulingStrategy::WeightedSampling => {
469 let mut weighted_tasks = Vec::new();
471 for task_config in &self.config.tasks {
472 if task_data.contains_key(&task_config.name) {
473 let weight = match task_config.priority {
474 TaskPriority::Low => 0.5,
475 TaskPriority::Normal => 1.0,
476 TaskPriority::High => 2.0,
477 TaskPriority::Critical => 3.0,
478 };
479 for _ in 0..(weight * 10.0) as usize {
480 weighted_tasks.push(task_config.name.clone());
481 }
482 }
483 }
484 if weighted_tasks.is_empty() {
485 return Ok(Vec::new());
486 }
487 let selected_task = &weighted_tasks[self.step_counter % weighted_tasks.len()];
488 Ok(vec![selected_task.clone()])
489 },
490 TaskSchedulingStrategy::Random => {
491 let task_names: Vec<String> = task_data.keys().cloned().collect();
492 if task_names.is_empty() {
493 return Ok(Vec::new());
494 }
495 let random_idx = fastrand::usize(..task_names.len());
496 Ok(vec![task_names[random_idx].clone()])
497 },
498 _ => {
499 Ok(task_data.keys().cloned().collect())
501 },
502 }
503 }
504
505 fn balance_losses(
507 &self,
508 task_losses: &HashMap<String, Tensor>,
509 ) -> Result<HashMap<String, Tensor>> {
510 match &self.config.loss_balancing {
511 LossBalancingStrategy::EqualWeighting => Ok(task_losses.clone()),
512 LossBalancingStrategy::ManualWeighting { weights } => {
513 let mut balanced = HashMap::new();
514 for (i, (task_name, loss)) in task_losses.iter().enumerate() {
515 let weight = weights.get(i).copied().unwrap_or(1.0);
516 balanced.insert(task_name.clone(), loss.scalar_mul(weight)?);
517 }
518 Ok(balanced)
519 },
520 LossBalancingStrategy::UncertaintyWeighting => {
521 Ok(task_losses.clone()) },
525 LossBalancingStrategy::DynamicWeightAverage => {
526 self.apply_dynamic_weight_average(task_losses)
528 },
529 LossBalancingStrategy::GradNorm { alpha } => {
530 self.apply_gradnorm(task_losses, *alpha)
532 },
533 _ => Ok(task_losses.clone()),
534 }
535 }
536
    /// Dynamic-weight-averaging-style rebalancing: scale each task's loss by
    /// exp(loss ratio / T) relative to its previous loss.
    ///
    /// NOTE(review): `get_previous_task_loss` is currently a stub returning
    /// 1.0, so the weight reduces to exp(current_loss / T); unlike canonical
    /// DWA the weights are also not normalized across tasks — confirm.
    fn apply_dynamic_weight_average(
        &self,
        task_losses: &HashMap<String, Tensor>,
    ) -> Result<HashMap<String, Tensor>> {
        let mut balanced = HashMap::new();

        // Need at least two steps of history before re-weighting.
        if self.step_counter < 2 {
            return Ok(task_losses.clone());
        }

        // Temperature softens the weight distribution (DWA uses T = 2).
        let temperature = 2.0;
        for (task_name, loss) in task_losses {
            let prev_loss = self.get_previous_task_loss(task_name);
            let current_loss = loss.to_scalar().unwrap_or(0.0);

            let weight = if prev_loss > 0.0 {
                let relative_decrease = current_loss / prev_loss;
                (relative_decrease / temperature).exp()
            } else {
                // No meaningful history: leave the loss unscaled.
                1.0
            };

            balanced.insert(task_name.clone(), loss.clone().mul_scalar(weight)?);
        }

        Ok(balanced)
    }
568
    /// GradNorm loss balancing (placeholder).
    ///
    /// NOTE(review): stub — the `alpha` asymmetry parameter is ignored and
    /// losses are returned unchanged. A real implementation needs per-task
    /// gradient norms with respect to the shared parameters.
    fn apply_gradnorm(
        &self,
        task_losses: &HashMap<String, Tensor>,
        _alpha: f32,
    ) -> Result<HashMap<String, Tensor>> {
        Ok(task_losses.clone())
    }
579
580 fn update_task_weights(&mut self, task_losses: &HashMap<String, Tensor>) -> Result<()> {
582 match &self.config.loss_balancing {
583 LossBalancingStrategy::DynamicWeightAverage => {
584 for (task_name, loss) in task_losses {
586 let current_loss = loss.to_scalar().unwrap_or(0.0);
587 if let Some(weight) = self.task_weights.get_mut(task_name) {
590 *weight = (*weight * 0.9 + current_loss * 0.1).clamp(0.1, 10.0);
591 }
592 }
593 },
594 _ => {
595 },
597 }
598 Ok(())
599 }
600
    /// Previous-step loss for a task, used by dynamic weight averaging.
    ///
    /// NOTE(review): placeholder — always returns 1.0 because per-task loss
    /// history is not stored yet (`task_performance` tracks accuracy, not
    /// loss).
    fn get_previous_task_loss(&self, _task_name: &str) -> f32 {
        1.0
    }
607
608 fn compute_auxiliary_losses(&self, task_data: &HashMap<String, TaskBatch>) -> Result<Tensor> {
610 let mut aux_loss: Tensor = Tensor::zeros(&[1])?;
611
612 for aux_config in &self.config.auxiliary_tasks {
613 if self.should_train_auxiliary_task(aux_config) {
614 if let Some(aux_data) = task_data.get(&aux_config.name) {
615 let aux_task_loss: Tensor =
616 self.compute_auxiliary_task_loss(aux_config, aux_data)?;
617 let weighted_loss: Tensor = aux_task_loss.mul_scalar(aux_config.weight)?;
618 aux_loss = aux_loss.add(&weighted_loss)?;
619 }
620 }
621 }
622
623 Ok(aux_loss)
624 }
625
626 fn should_train_auxiliary_task(&self, aux_config: &AuxiliaryTaskConfig) -> bool {
628 match &aux_config.frequency {
629 AuxiliaryTaskFrequency::EveryNSteps(n) => self.step_counter % n == 0,
630 AuxiliaryTaskFrequency::WithProbability(p) => fastrand::f32() < *p,
631 AuxiliaryTaskFrequency::Continuous => true,
632 AuxiliaryTaskFrequency::EpochRange { start, end } => {
633 let current_epoch = self.step_counter / 1000; current_epoch >= *start && current_epoch <= *end
635 },
636 }
637 }
638
    /// Loss for one auxiliary task, computed from the shared backbone
    /// features.
    ///
    /// NOTE(review): only LM and MLM are dispatched (both currently stubs);
    /// every other auxiliary type contributes a zero loss.
    fn compute_auxiliary_task_loss(
        &self,
        aux_config: &AuxiliaryTaskConfig,
        data: &TaskBatch,
    ) -> Result<Tensor> {
        let shared_features: Tensor = self.base_model.forward(data.inputs.clone())?;

        match &aux_config.auxiliary_type {
            AuxiliaryType::LanguageModeling => {
                self.compute_lm_loss(&shared_features, &data.targets)
            },
            AuxiliaryType::MaskedLanguageModeling => {
                self.compute_mlm_loss(&shared_features, &data.targets)
            },
            _ => {
                Ok(Tensor::zeros(&[1])?)
            },
        }
    }
663
    /// Causal language-modeling auxiliary loss.
    ///
    /// NOTE(review): stub — ignores its inputs and returns a zero scalar.
    fn compute_lm_loss(&self, _features: &Tensor, _targets: &Tensor) -> Result<Tensor> {
        Tensor::zeros(&[1])
    }
669
    /// Masked language-modeling auxiliary loss.
    ///
    /// NOTE(review): stub — ignores its inputs and returns a zero scalar.
    fn compute_mlm_loss(&self, _features: &Tensor, _targets: &Tensor) -> Result<Tensor> {
        Tensor::zeros(&[1])
    }
675
    /// Loss for a single named task, dispatched on its configured type.
    ///
    /// # Errors
    /// Returns `invalid_input` if `task_name` is not in the configuration.
    fn compute_task_loss(
        &self,
        task_name: &str,
        outputs: &Tensor,
        targets: &Tensor,
    ) -> Result<Tensor> {
        let task_config = self
            .config
            .tasks
            .iter()
            .find(|t| t.name == task_name)
            .ok_or_else(|| invalid_input(format!("Task not found: {}", task_name)))?;

        match &task_config.task_type {
            TaskType::Classification { .. } => {
                // NOTE(review): despite the variable name, this is `softmax`
                // (not log-softmax), so the result is -mean(sum(t * p)) and
                // not cross-entropy. Confirm whether a `.log()` is missing.
                let log_probs = outputs.softmax(-1)?;
                let nll_loss = targets.mul(&log_probs)?.sum(Some(vec![1]), false)?;
                Ok(nll_loss.mean()?.mul_scalar(-1.0)?)
            },
            TaskType::Regression { loss_type, .. } => {
                match loss_type {
                    // Mean squared error.
                    RegressionLossType::MSE => {
                        let diff = outputs.sub(targets)?;
                        Ok(diff.mul(&diff)?.mean()?)
                    },
                    // Mean absolute error.
                    RegressionLossType::MAE => {
                        let diff = outputs.sub(targets)?;
                        Ok(diff.abs()?.mean()?)
                    },
                    RegressionLossType::Huber { delta } => {
                        // NOTE(review): the large-error branch is computed but
                        // never used, so this currently degenerates to
                        // 0.5 * MSE for all residuals. A where/select over
                        // |diff| vs delta is needed for true Huber loss.
                        let diff = outputs.sub(targets)?;
                        let abs_diff = diff.abs()?;
                        let small_loss = diff.mul(&diff)?.mul_scalar(0.5)?;
                        let _large_loss =
                            abs_diff.mul_scalar(*delta)?.sub_scalar(*delta * *delta * 0.5)?;
                        Ok(small_loss.mean()?)
                    },
                    // LogCosh currently falls back to MSE.
                    _ => {
                        let diff = outputs.sub(targets)?;
                        Ok(diff.mul(&diff)?.mean()?)
                    },
                }
            },
            // Other task types are not implemented; they contribute zero loss.
            _ => {
                Ok(Tensor::zeros(&[1])?)
            },
        }
    }
729
    /// Quality metric for a task: exact-match for classification, a floored
    /// R-squared for regression, and 0.0 for unimplemented task types.
    ///
    /// # Errors
    /// Returns `invalid_input` if `task_name` is not in the configuration.
    fn compute_task_accuracy(
        &self,
        task_name: &str,
        outputs: &Tensor,
        targets: &Tensor,
    ) -> Result<f32> {
        let task_config = self
            .config
            .tasks
            .iter()
            .find(|t| t.name == task_name)
            .ok_or_else(|| invalid_input(format!("Task not found: {}", task_name)))?;

        match &task_config.task_type {
            TaskType::Classification { .. } => {
                // NOTE(review): comparing via `to_scalar` presumably only
                // works when argmax reduces to a single element (batch size
                // 1) — confirm. On failure the unwrap defaults (-1.0 vs
                // -2.0) differ, so the example counts as incorrect.
                let predicted = outputs.argmax(-1)?;
                let target_class = targets.argmax(-1)?;
                let correct = (predicted.to_scalar().unwrap_or(-1.0)
                    == target_class.to_scalar().unwrap_or(-2.0))
                    as i32 as f32;
                Ok(correct)
            },
            TaskType::Regression { .. } => {
                // R^2 = 1 - MSE / Var(targets), floored at 0 so the metric
                // behaves like an accuracy score.
                let diff = outputs.sub(targets)?;
                let mse = diff.mul(&diff)?.mean()?;
                let mean_targets = targets.mean()?;
                let diff_from_mean = targets.sub(&mean_targets)?;
                let variance = diff_from_mean.pow_scalar(2.0)?.mean()?;
                let r_squared =
                    1.0 - mse.to_scalar().unwrap_or(1.0) / variance.to_scalar().unwrap_or(1.0);
                Ok(r_squared.max(0.0))
            },
            _ => Ok(0.0),
        }
    }
767
768 pub fn evaluate_all_tasks(
770 &self,
771 test_data: &HashMap<String, TaskBatch>,
772 ) -> Result<MultiTaskEvaluation> {
773 let mut task_evaluations = HashMap::new();
774
775 for (task_name, batch) in test_data {
776 if let Some(task_head) = self.task_heads.get(task_name) {
777 let shared_features = self.base_model.forward(batch.inputs.clone())?;
778 let task_outputs = task_head.forward(&shared_features)?;
779 let loss = self.compute_task_loss(task_name, &task_outputs, &batch.targets)?;
780 let accuracy =
781 self.compute_task_accuracy(task_name, &task_outputs, &batch.targets)?;
782
783 task_evaluations.insert(
784 task_name.clone(),
785 TaskEvaluation {
786 task_name: task_name.clone(),
787 loss: loss.to_scalar().unwrap_or(0.0),
788 accuracy,
789 num_examples: batch.inputs.shape()[0],
790 },
791 );
792 }
793 }
794
795 let overall_accuracy = if !task_evaluations.is_empty() {
796 task_evaluations.values().map(|e| e.accuracy).sum::<f32>()
797 / task_evaluations.len() as f32
798 } else {
799 0.0
800 };
801
802 Ok(MultiTaskEvaluation {
803 task_evaluations,
804 overall_accuracy,
805 step: self.step_counter,
806 })
807 }
808
809 pub fn get_mtl_stats(&self) -> MTLStats {
811 MTLStats {
812 num_tasks: self.config.tasks.len(),
813 task_weights: self.task_weights.clone(),
814 step_counter: self.step_counter,
815 architecture: self.config.architecture.clone(),
816 loss_balancing: self.config.loss_balancing.clone(),
817 }
818 }
819}
820
/// Task-specific output head applied on top of the shared backbone.
pub struct TaskHead {
    /// Linear layers applied in sequence during `forward`.
    layers: Vec<Linear>,
    /// The task type this head was built for (kept for introspection).
    #[allow(dead_code)]
    task_type: TaskType,
}
827
828impl TaskHead {
829 pub fn new(task_type: &TaskType) -> Result<Self> {
830 let mut layers = Vec::new();
831
832 match task_type {
833 TaskType::Classification { num_classes, .. } => {
834 layers.push(Linear::new(768, *num_classes, true)); },
837 TaskType::Regression { output_dim, .. } => {
838 layers.push(Linear::new(768, *output_dim, true));
839 },
840 _ => {
841 layers.push(Linear::new(768, 768, true));
843 },
844 }
845
846 Ok(Self {
847 layers,
848 task_type: task_type.clone(),
849 })
850 }
851
852 pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
853 let mut output = input.clone();
854 for layer in &self.layers {
855 output = layer.forward(output)?;
856 }
857 Ok(output)
858 }
859}
860
/// One batch of training or evaluation data for a single task.
#[derive(Debug, Clone)]
pub struct TaskBatch {
    /// Model inputs (first dimension is treated as the batch size).
    pub inputs: Tensor,
    /// Ground-truth targets matching the task's output layout.
    pub targets: Tensor,
    /// Name of the task this batch belongs to.
    pub task_name: String,
}
868
/// Mutable bookkeeping for the task scheduler.
///
/// NOTE(review): currently initialized but never read or updated by the
/// trainer — confirm whether it is still needed.
pub struct TaskSchedulerState {
    /// Index of the most recently scheduled task.
    pub current_task_index: usize,
    /// How many times each task has been scheduled, keyed by name.
    pub task_counters: HashMap<String, usize>,
}
874
875impl TaskSchedulerState {
876 pub fn new(_strategy: &TaskSchedulingStrategy) -> Self {
877 Self {
878 current_task_index: 0,
879 task_counters: HashMap::new(),
880 }
881 }
882}
883
/// Summary statistics of a task's gradients.
///
/// NOTE(review): stored in `MultiTaskLearningTrainer::gradient_stats` but
/// never populated in this module — confirm intended usage.
#[derive(Debug, Clone)]
pub struct GradientStats {
    /// Gradient norm.
    pub gradient_norm: f32,
    /// Gradient variance.
    pub gradient_variance: f32,
    /// Number of updates folded into these statistics.
    pub update_count: usize,
}
891
/// Result of one multi-task training step.
#[derive(Debug, Clone)]
pub struct MultiTaskOutput {
    /// Combined (balanced and weighted) loss for backpropagation.
    pub total_loss: Tensor,
    /// Scalarized per-task losses, keyed by task name.
    pub task_losses: HashMap<String, f32>,
    /// Per-task accuracy metrics for this step.
    pub task_accuracies: HashMap<String, f32>,
    /// Names of the tasks actually trained this step.
    pub active_tasks: Vec<String>,
    /// Snapshot of the loss weights in effect for this step.
    pub task_weights: HashMap<String, f32>,
}
901
/// Evaluation result for one task.
#[derive(Debug, Clone)]
pub struct TaskEvaluation {
    /// Task name (duplicates the map key for convenience).
    pub task_name: String,
    /// Scalarized evaluation loss.
    pub loss: f32,
    /// Task accuracy metric (see `compute_task_accuracy`).
    pub accuracy: f32,
    /// Number of examples in the evaluation batch.
    pub num_examples: usize,
}
910
/// Evaluation results across all tasks at a point in training.
#[derive(Debug, Clone)]
pub struct MultiTaskEvaluation {
    /// Per-task results, keyed by task name.
    pub task_evaluations: HashMap<String, TaskEvaluation>,
    /// Unweighted mean of the per-task accuracies (0.0 when empty).
    pub overall_accuracy: f32,
    /// Trainer step at which the evaluation was taken.
    pub step: usize,
}
918
/// Lightweight snapshot of trainer state for logging and monitoring.
#[derive(Debug, Clone)]
pub struct MTLStats {
    /// Number of configured tasks.
    pub num_tasks: usize,
    /// Current per-task loss weights.
    pub task_weights: HashMap<String, f32>,
    /// Training steps completed so far.
    pub step_counter: usize,
    /// Configured parameter-sharing architecture.
    pub architecture: MTLArchitecture,
    /// Configured loss-balancing strategy.
    pub loss_balancing: LossBalancingStrategy,
}
928
929pub mod utils {
931 use super::*;
932
933 pub fn hard_parameter_sharing_config(
935 tasks: Vec<TaskConfig>,
936 shared_layers: usize,
937 task_specific_layers: usize,
938 ) -> MTLConfig {
939 MTLConfig {
940 architecture: MTLArchitecture::HardParameterSharing {
941 shared_layers,
942 task_specific_layers,
943 },
944 tasks,
945 ..Default::default()
946 }
947 }
948
949 pub fn soft_parameter_sharing_config(
951 tasks: Vec<TaskConfig>,
952 regularization_weight: f32,
953 ) -> MTLConfig {
954 MTLConfig {
955 architecture: MTLArchitecture::SoftParameterSharing {
956 regularization_weight,
957 regularization_type: RegularizationType::L2Regularization,
958 },
959 tasks,
960 ..Default::default()
961 }
962 }
963
964 pub fn mmoe_config(tasks: Vec<TaskConfig>, num_experts: usize, expert_dim: usize) -> MTLConfig {
966 MTLConfig {
967 architecture: MTLArchitecture::MultiGateMixtureOfExperts {
968 num_experts,
969 expert_dim,
970 num_gates: tasks.len(),
971 },
972 tasks,
973 ..Default::default()
974 }
975 }
976
977 pub fn classification_task(name: &str, num_classes: usize) -> TaskConfig {
979 TaskConfig::new(
980 name,
981 TaskType::Classification {
982 num_classes,
983 use_class_weights: false,
984 },
985 )
986 }
987
988 pub fn regression_task(name: &str, output_dim: usize) -> TaskConfig {
990 TaskConfig::new(
991 name,
992 TaskType::Regression {
993 output_dim,
994 loss_type: RegressionLossType::MSE,
995 },
996 )
997 }
998
999 pub fn mlm_auxiliary_task(weight: f32) -> AuxiliaryTaskConfig {
1001 AuxiliaryTaskConfig {
1002 name: "mlm".to_string(),
1003 auxiliary_type: AuxiliaryType::MaskedLanguageModeling,
1004 weight,
1005 frequency: AuxiliaryTaskFrequency::EveryNSteps(10),
1006 }
1007 }
1008
1009 pub fn compute_task_similarity(
1011 task_performances: &HashMap<String, Vec<f32>>,
1012 ) -> HashMap<(String, String), f32> {
1013 let mut similarities = HashMap::new();
1014 let tasks: Vec<String> = task_performances.keys().cloned().collect();
1015
1016 for i in 0..tasks.len() {
1017 for j in i + 1..tasks.len() {
1018 let task1 = &tasks[i];
1019 let task2 = &tasks[j];
1020
1021 if let (Some(perf1), Some(perf2)) =
1022 (task_performances.get(task1), task_performances.get(task2))
1023 {
1024 let similarity = compute_correlation(perf1, perf2);
1025 similarities.insert((task1.clone(), task2.clone()), similarity);
1026 similarities.insert((task2.clone(), task1.clone()), similarity);
1027 }
1028 }
1029 }
1030
1031 similarities
1032 }
1033
1034 pub fn compute_correlation(seq1: &[f32], seq2: &[f32]) -> f32 {
1036 if seq1.len() != seq2.len() || seq1.is_empty() {
1037 return 0.0;
1038 }
1039
1040 let n = seq1.len() as f32;
1041 let mean1 = seq1.iter().sum::<f32>() / n;
1042 let mean2 = seq2.iter().sum::<f32>() / n;
1043
1044 let mut numerator = 0.0;
1045 let mut denom1 = 0.0;
1046 let mut denom2 = 0.0;
1047
1048 for i in 0..seq1.len() {
1049 let diff1 = seq1[i] - mean1;
1050 let diff2 = seq2[i] - mean2;
1051 numerator += diff1 * diff2;
1052 denom1 += diff1 * diff1;
1053 denom2 += diff2 * diff2;
1054 }
1055
1056 if denom1 * denom2 > 0.0 {
1057 numerator / (denom1 * denom2).sqrt()
1058 } else {
1059 0.0
1060 }
1061 }
1062
1063 pub fn analyze_mtl_effectiveness(
1065 single_task_performances: &HashMap<String, f32>,
1066 multi_task_performances: &HashMap<String, f32>,
1067 ) -> MTLAnalysis {
1068 let mut positive_transfer_tasks = Vec::new();
1069 let mut negative_transfer_tasks = Vec::new();
1070 let mut total_improvement = 0.0;
1071 let mut num_tasks = 0;
1072
1073 for (task_name, &mtl_perf) in multi_task_performances {
1074 if let Some(&single_perf) = single_task_performances.get(task_name) {
1075 let improvement = mtl_perf - single_perf;
1076 total_improvement += improvement;
1077 num_tasks += 1;
1078
1079 if improvement > 0.0 {
1080 positive_transfer_tasks.push(task_name.clone());
1081 } else if improvement < 0.0 {
1082 negative_transfer_tasks.push(task_name.clone());
1083 }
1084 }
1085 }
1086
1087 let average_improvement =
1088 if num_tasks > 0 { total_improvement / num_tasks as f32 } else { 0.0 };
1089
1090 MTLAnalysis {
1091 average_improvement,
1092 positive_transfer_tasks,
1093 negative_transfer_tasks,
1094 num_tasks,
1095 }
1096 }
1097}
1098
/// Outcome of comparing multi-task against single-task training.
#[derive(Debug, Clone)]
pub struct MTLAnalysis {
    /// Mean score delta (multi-task minus single-task) over compared tasks.
    pub average_improvement: f32,
    /// Tasks that improved under multi-task training.
    pub positive_transfer_tasks: Vec<String>,
    /// Tasks that regressed under multi-task training.
    pub negative_transfer_tasks: Vec<String>,
    /// Number of tasks present in both result sets.
    pub num_tasks: usize,
}
1107
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config: no tasks, optional features off, 8/2 hard sharing.
    #[test]
    fn test_mtl_config_default() {
        let config = MTLConfig::default();
        assert_eq!(config.tasks.len(), 0);
        assert!(!config.use_task_embeddings);
        assert!(!config.use_auxiliary_tasks);

        if let MTLArchitecture::HardParameterSharing {
            shared_layers,
            task_specific_layers,
        } = config.architecture
        {
            assert_eq!(shared_layers, 8);
            assert_eq!(task_specific_layers, 2);
        } else {
            panic!("Expected HardParameterSharing architecture");
        }
    }

    /// `TaskConfig::new` defaults plus the `with_weight` builder.
    #[test]
    fn test_task_config() {
        let task = TaskConfig::new(
            "test",
            TaskType::Classification {
                num_classes: 10,
                use_class_weights: false,
            },
        );

        assert_eq!(task.name, "test");
        assert_eq!(task.weight, 1.0);
        assert!(!task.is_main_task);

        let weighted_task = task.with_weight(2.0);
        assert_eq!(weighted_task.weight, 2.0);
    }

    /// The classification helper should pass through name and class count.
    #[test]
    fn test_classification_task_util() {
        let task = utils::classification_task("sentiment", 3);
        assert_eq!(task.name, "sentiment");

        if let TaskType::Classification { num_classes, .. } = task.task_type {
            assert_eq!(num_classes, 3);
        } else {
            panic!("Expected Classification task type");
        }
    }

    /// The regression helper should pass through name and output dim.
    #[test]
    fn test_regression_task_util() {
        let task = utils::regression_task("score", 1);
        assert_eq!(task.name, "score");

        if let TaskType::Regression { output_dim, .. } = task.task_type {
            assert_eq!(output_dim, 1);
        } else {
            panic!("Expected Regression task type");
        }
    }

    /// Hard-sharing helper should record the layer split and the tasks.
    #[test]
    fn test_hard_parameter_sharing_config() {
        let tasks = vec![
            utils::classification_task("task1", 5),
            utils::regression_task("task2", 1),
        ];

        let config = utils::hard_parameter_sharing_config(tasks, 6, 2);
        assert_eq!(config.tasks.len(), 2);

        if let MTLArchitecture::HardParameterSharing {
            shared_layers,
            task_specific_layers,
        } = config.architecture
        {
            assert_eq!(shared_layers, 6);
            assert_eq!(task_specific_layers, 2);
        } else {
            panic!("Expected HardParameterSharing architecture");
        }
    }

    /// Soft-sharing helper should record the regularization weight.
    #[test]
    fn test_soft_parameter_sharing_config() {
        let tasks = vec![utils::classification_task("task1", 5)];
        let config = utils::soft_parameter_sharing_config(tasks, 0.01);

        if let MTLArchitecture::SoftParameterSharing {
            regularization_weight,
            ..
        } = config.architecture
        {
            assert_eq!(regularization_weight, 0.01);
        } else {
            panic!("Expected SoftParameterSharing architecture");
        }
    }

    /// MMoE helper should create exactly one gate per task.
    #[test]
    fn test_mmoe_config() {
        let tasks = vec![
            utils::classification_task("task1", 5),
            utils::classification_task("task2", 3),
        ];

        let config = utils::mmoe_config(tasks, 4, 128);

        if let MTLArchitecture::MultiGateMixtureOfExperts {
            num_experts,
            expert_dim,
            num_gates,
        } = config.architecture
        {
            assert_eq!(num_experts, 4);
            assert_eq!(expert_dim, 128);
            assert_eq!(num_gates, 2);
        } else {
            panic!("Expected MultiGateMixtureOfExperts architecture");
        }
    }

    /// MLM auxiliary helper should fix name/type and pass the weight through.
    #[test]
    fn test_mlm_auxiliary_task() {
        let aux_task = utils::mlm_auxiliary_task(0.1);
        assert_eq!(aux_task.name, "mlm");
        assert_eq!(aux_task.weight, 0.1);

        if let AuxiliaryType::MaskedLanguageModeling = aux_task.auxiliary_type {
            // Expected variant; nothing further to check.
        } else {
            panic!("Expected MaskedLanguageModeling auxiliary type");
        }
    }

    /// Perfectly correlated / anti-correlated sequences hit +1 and -1.
    #[test]
    fn test_compute_correlation() {
        let seq1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let seq2 = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let correlation = utils::compute_correlation(&seq1, &seq2);
        assert!((correlation - 1.0).abs() < 1e-6);

        let seq3 = vec![5.0, 4.0, 3.0, 2.0, 1.0];
        let correlation_neg = utils::compute_correlation(&seq1, &seq3);
        assert!((correlation_neg + 1.0).abs() < 1e-6);
    }

    /// Transfer analysis: task1 improves, task2 regresses, task3 improves.
    #[test]
    fn test_mtl_analysis() {
        let mut single_task = HashMap::new();
        single_task.insert("task1".to_string(), 0.8);
        single_task.insert("task2".to_string(), 0.7);
        single_task.insert("task3".to_string(), 0.6);

        let mut multi_task = HashMap::new();
        multi_task.insert("task1".to_string(), 0.85);
        multi_task.insert("task2".to_string(), 0.65);
        multi_task.insert("task3".to_string(), 0.65);

        let analysis = utils::analyze_mtl_effectiveness(&single_task, &multi_task);
        assert_eq!(analysis.num_tasks, 3);
        assert_eq!(analysis.positive_transfer_tasks.len(), 2);
        assert_eq!(analysis.negative_transfer_tasks.len(), 1);
        assert!(analysis.positive_transfer_tasks.contains(&"task1".to_string()));
        assert!(analysis.negative_transfer_tasks.contains(&"task2".to_string()));
    }
}