Skip to main content

trustformers_models/
multi_task_learning.rs

//! # Multi-Task Learning Framework
//!
//! This module provides a comprehensive framework for multi-task learning,
//! enabling models to learn multiple related tasks simultaneously to improve
//! generalization and efficiency.
//!
//! ## Features
//!
//! - **Multiple MTL Architectures**: Hard parameter sharing, soft parameter sharing, task-specific layers
//! - **Loss Balancing**: Various strategies for balancing losses across tasks
//! - **Task Weighting**: Dynamic and static task weight adjustment
//! - **Auxiliary Tasks**: Support for auxiliary tasks to improve main task performance
//! - **Task Clustering**: Grouping related tasks for better sharing
//! - **Evaluation Metrics**: Specialized metrics for multi-task scenarios
//!
//! ## Usage
//!
//! The trainer wraps an existing shared model; each training step consumes one batch per task,
//! keyed by task name.
//!
//! ```rust,no_run
//! use std::collections::HashMap;
//! use trustformers_core::{tensor::Tensor, traits::Model, Result};
//! use trustformers_models::multi_task_learning::{
//!     LossBalancingStrategy, MTLArchitecture, MTLConfig, MultiTaskLearningTrainer,
//!     RegressionLossType, TaskBatch, TaskConfig, TaskType,
//! };
//!
//! fn train<M: Model<Input = Tensor, Output = Tensor>>(
//!     base_model: M,
//!     task_data: &HashMap<String, TaskBatch>,
//! ) -> Result<()> {
//!     let config = MTLConfig {
//!         architecture: MTLArchitecture::HardParameterSharing {
//!             shared_layers: 8,
//!             task_specific_layers: 2,
//!         },
//!         loss_balancing: LossBalancingStrategy::DynamicWeightAverage,
//!         tasks: vec![
//!             TaskConfig::new(
//!                 "classification",
//!                 TaskType::Classification { num_classes: 10, use_class_weights: false },
//!             ),
//!             TaskConfig::new(
//!                 "regression",
//!                 TaskType::Regression { output_dim: 1, loss_type: RegressionLossType::MSE },
//!             ),
//!         ],
//!         ..Default::default()
//!     };
//!
//!     let mut trainer = MultiTaskLearningTrainer::new(base_model, config)?;
//!     let output = trainer.train_multi_task_step(task_data)?;
//!     println!("trained tasks: {:?}", output.active_tasks);
//!     Ok(())
//! }
//! ```
39
40use serde::{Deserialize, Serialize};
41use std::collections::HashMap;
42use trustformers_core::{
43    errors::invalid_input,
44    layers::Linear,
45    tensor::Tensor,
46    traits::{Layer, Model},
47    Result,
48};
49
/// Configuration for multi-task learning.
///
/// Consumed by [`MultiTaskLearningTrainer::new`], which builds one task head and one loss
/// weight per entry in [`MTLConfig::tasks`]. Use `..Default::default()` to fill in the
/// fields you do not need to customise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MTLConfig {
    /// Multi-task learning architecture.
    ///
    /// NOTE(review): the trainer in this module only reports this value (via
    /// `get_mtl_stats`); head construction does not branch on it.
    pub architecture: MTLArchitecture,
    /// Strategy for balancing losses across tasks
    pub loss_balancing: LossBalancingStrategy,
    /// Task configurations.
    ///
    /// Task names are the keys under which the trainer stores heads and weights, so they
    /// must match the keys of the per-task batch maps passed to the trainer.
    pub tasks: Vec<TaskConfig>,
    /// Whether to use task embeddings (not read by the trainer code in this module).
    pub use_task_embeddings: bool,
    /// Task embedding dimension (not read by the trainer code in this module).
    pub task_embedding_dim: usize,
    /// When `true`, the weighted losses of `auxiliary_tasks` are added to each step's total loss.
    pub use_auxiliary_tasks: bool,
    /// Auxiliary task configurations
    pub auxiliary_tasks: Vec<AuxiliaryTaskConfig>,
    /// Task clustering configuration (not read by the trainer code in this module).
    pub task_clustering: Option<TaskClusteringConfig>,
    /// Evaluation frequency for each task (not read by the trainer code in this module).
    pub evaluation_frequency: usize,
    /// Whether to use task scheduling (see `task_scheduling`).
    pub use_task_scheduling: bool,
    /// Strategy used to choose which tasks are trained on a given step.
    pub task_scheduling: TaskSchedulingStrategy,
}
76
77impl Default for MTLConfig {
78    fn default() -> Self {
79        Self {
80            architecture: MTLArchitecture::HardParameterSharing {
81                shared_layers: 8,
82                task_specific_layers: 2,
83            },
84            loss_balancing: LossBalancingStrategy::EqualWeighting,
85            tasks: Vec::new(),
86            use_task_embeddings: false,
87            task_embedding_dim: 64,
88            use_auxiliary_tasks: false,
89            auxiliary_tasks: Vec::new(),
90            task_clustering: None,
91            evaluation_frequency: 1000,
92            use_task_scheduling: false,
93            task_scheduling: TaskSchedulingStrategy::RoundRobin,
94        }
95    }
96}
97
/// Multi-task learning architectures.
///
/// NOTE(review): in the trainer code visible in this module, the selected architecture is
/// only stored and reported (see `MultiTaskLearningTrainer::get_mtl_stats`); task heads are
/// built the same way regardless of the variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MTLArchitecture {
    /// Hard parameter sharing - shared bottom layers, task-specific top layers
    HardParameterSharing {
        shared_layers: usize,
        task_specific_layers: usize,
    },
    /// Soft parameter sharing - each task has its own parameters with regularization
    SoftParameterSharing {
        regularization_weight: f32,
        regularization_type: RegularizationType,
    },
    /// Multi-gate mixture of experts
    MultiGateMixtureOfExperts {
        num_experts: usize,
        expert_dim: usize,
        num_gates: usize,
    },
    /// Cross-stitch networks
    CrossStitchNetworks {
        num_tasks: usize,
        cross_stitch_layers: Vec<usize>,
    },
    /// Task routing networks
    TaskRoutingNetworks {
        num_routers: usize,
        routing_dim: usize,
    },
    /// Progressive Neural Networks for MTL
    ProgressiveNetworks {
        lateral_connections: bool,
        adapter_layers: bool,
    },
    /// Attention-based task sharing
    AttentionBasedSharing {
        attention_dim: usize,
        num_attention_heads: usize,
    },
}
138
/// Regularization types for soft parameter sharing.
///
/// Used by [`MTLArchitecture::SoftParameterSharing`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RegularizationType {
    /// L2 regularization between task parameters
    L2Regularization,
    /// Trace norm regularization
    TraceNorm,
    /// Group LASSO
    GroupLasso,
    /// Elastic net, with separate weights for the L1 and L2 terms
    ElasticNet { l1_weight: f32, l2_weight: f32 },
}
151
/// Strategies for balancing losses across tasks.
///
/// The balanced per-task losses are additionally multiplied by the trainer's per-task
/// weights before being summed into the total loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossBalancingStrategy {
    /// Equal weighting for all tasks (losses are passed through unchanged)
    EqualWeighting,
    /// Manual task weights: per-task loss multipliers; tasks without an entry use 1.0
    ManualWeighting { weights: Vec<f32> },
    /// Uncertainty-based weighting (currently passes losses through unchanged)
    UncertaintyWeighting,
    /// Dynamic weight average; also makes the trainer adapt its per-task weights each step
    DynamicWeightAverage,
    /// GradNorm - gradient magnitude balancing (currently passes losses through unchanged)
    GradNorm { alpha: f32 },
    /// Task-balanced sampling (currently passes losses through unchanged)
    TaskBalancedSampling,
    /// Focal loss for hard tasks (currently passes losses through unchanged)
    FocalLoss { gamma: f32 },
    /// Meta-learning based weighting (currently passes losses through unchanged)
    MetaLearning { meta_lr: f32 },
}
172
/// Configuration of a single task trained by [`MultiTaskLearningTrainer`].
///
/// Build with [`TaskConfig::new`] and the `with_*` / `as_main_task` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskConfig {
    /// Task name/identifier; used as the key for the task's head, weight and batches
    pub name: String,
    /// Task type and parameters; determines the head shape and the loss/metric used
    pub task_type: TaskType,
    /// Initial loss weight for this task (defaults to 1.0 via `TaskConfig::new`)
    pub weight: f32,
    /// Task priority (consulted by the weighted task scheduler)
    pub priority: TaskPriority,
    /// Whether this is the main task
    pub is_main_task: bool,
    /// Task-specific learning rate (`None` = use the global one)
    pub learning_rate: Option<f32>,
    /// Task-specific batch size (`None` = use the global one)
    pub batch_size: Option<usize>,
}
191
192impl TaskConfig {
193    pub fn new(name: &str, task_type: TaskType) -> Self {
194        Self {
195            name: name.to_string(),
196            task_type,
197            weight: 1.0,
198            priority: TaskPriority::Normal,
199            is_main_task: false,
200            learning_rate: None,
201            batch_size: None,
202        }
203    }
204
205    pub fn with_weight(mut self, weight: f32) -> Self {
206        self.weight = weight;
207        self
208    }
209
210    pub fn with_priority(mut self, priority: TaskPriority) -> Self {
211        self.priority = priority;
212        self
213    }
214
215    pub fn as_main_task(mut self) -> Self {
216        self.is_main_task = true;
217        self
218    }
219}
220
/// Task types and their specific parameters.
///
/// The variant chooses the output width of the task head and the loss/metric the trainer
/// computes. Only `Classification` and `Regression` currently have non-zero losses in the
/// trainer; the other variants yield a zero loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskType {
    /// Classification task (`use_class_weights` is not read by the trainer in this module)
    Classification {
        num_classes: usize,
        use_class_weights: bool,
    },
    /// Regression task, scored with the chosen regression loss
    Regression {
        output_dim: usize,
        loss_type: RegressionLossType,
    },
    /// Sequence labeling task
    SequenceLabeling { num_labels: usize, use_crf: bool },
    /// Generation task
    Generation {
        vocab_size: usize,
        max_length: usize,
    },
    /// Ranking task
    Ranking { ranking_type: RankingType },
    /// Auxiliary task
    Auxiliary { auxiliary_type: AuxiliaryType },
}
246
/// Regression loss types used by [`TaskType::Regression`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RegressionLossType {
    /// Mean squared error
    MSE,
    /// Mean absolute error
    MAE,
    /// Huber loss; `delta` is the threshold between the quadratic and linear regimes
    Huber { delta: f32 },
    /// Log-cosh loss (the trainer in this module currently falls back to MSE for it)
    LogCosh,
}
255
/// Ranking task types used by [`TaskType::Ranking`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RankingType {
    /// Compare pairs of items
    Pairwise,
    /// Score a whole list at once
    Listwise,
    /// Score each item independently
    Pointwise,
}
263
/// Auxiliary task types.
///
/// In the trainer in this module, `LanguageModeling` and `MaskedLanguageModeling` are routed
/// to dedicated (currently zero-valued) loss functions; all other variants contribute a zero
/// loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuxiliaryType {
    LanguageModeling,
    MaskedLanguageModeling,
    NextSentencePrediction,
    SentenceOrderPrediction,
    WordOrderPrediction,
    /// User-defined auxiliary task identified by name
    Custom { name: String },
}
274
/// Task priorities.
///
/// The weighted task scheduler gives tasks schedule share in the ratio
/// Low 0.5 : Normal 1 : High 2 : Critical 3.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskPriority {
    Low,
    Normal,
    High,
    Critical,
}
283
/// Auxiliary task configuration.
///
/// Auxiliary batches are looked up by `name` in the same per-task batch map as the main
/// tasks; the resulting loss is scaled by `weight` before being added to the total loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuxiliaryTaskConfig {
    /// Key of this task's batch in the per-task batch map
    pub name: String,
    /// Which auxiliary objective to compute
    pub auxiliary_type: AuxiliaryType,
    /// Multiplier applied to the auxiliary loss
    pub weight: f32,
    /// How often the auxiliary task is trained
    pub frequency: AuxiliaryTaskFrequency,
}
292
/// Frequency of auxiliary task training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuxiliaryTaskFrequency {
    /// Train when the step counter is a multiple of N (N should be non-zero)
    EveryNSteps(usize),
    /// Train with probability P on each step (P is expected in [0, 1])
    WithProbability(f32),
    /// Train on every step
    Continuous,
    /// Train only in the inclusive epoch range `start..=end`.
    ///
    /// The trainer derives epochs from its step counter, assuming 1000 steps per epoch.
    EpochRange { start: usize, end: usize },
}
305
/// Task clustering configuration.
///
/// NOTE(review): not read by the trainer code in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskClusteringConfig {
    /// How tasks are grouped
    pub clustering_method: ClusteringMethod,
    /// Number of clusters to form
    pub num_clusters: usize,
    /// How often the clustering is recomputed
    pub update_frequency: usize,
}
313
/// Clustering methods for tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusteringMethod {
    /// Cluster by gradient similarity
    GradientSimilarity,
    /// Cluster by task performance correlation
    PerformanceCorrelation,
    /// Cluster by data similarity
    DataSimilarity,
    /// Manual clustering: each inner list holds the task names of one cluster
    Manual { clusters: Vec<Vec<String>> },
}
326
/// Task scheduling strategies: which tasks are trained on a given step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskSchedulingStrategy {
    /// Train one task per step, cycling through the available tasks
    RoundRobin,
    /// Train one task per step, cycling deterministically through a schedule in which each
    /// task's share is proportional to its [`TaskPriority`]
    WeightedSampling,
    /// Performance-based scheduling (currently trains all available tasks each step)
    PerformanceBased,
    /// Curriculum-based scheduling (currently trains all available tasks each step)
    CurriculumBased { difficulty_order: Vec<String> },
    /// Train one uniformly random task per step
    Random,
}
341
/// Multi-task learning trainer.
///
/// Runs a shared base model followed by one [`TaskHead`] per task, combines the per-task
/// losses according to the configured balancing strategy, and tracks per-task weights and
/// metric history across steps.
pub struct MultiTaskLearningTrainer<M: Model> {
    /// Base model (shared layers) applied to every task's inputs
    pub base_model: M,
    /// Task-specific heads, keyed by task name
    pub task_heads: HashMap<String, TaskHead>,
    /// Configuration
    pub config: MTLConfig,
    /// Current per-task loss weights, keyed by task name
    pub task_weights: HashMap<String, f32>,
    /// Per-task metric history (one entry appended per trained step), keyed by task name
    pub task_performance: HashMap<String, Vec<f32>>,
    /// Number of completed training steps
    pub step_counter: usize,
    /// Task scheduling state (not read by the trainer methods in this module)
    pub scheduler_state: TaskSchedulerState,
    /// Gradient statistics for balancing (created empty; not updated in this module)
    pub gradient_stats: HashMap<String, GradientStats>,
}
361
362impl<M: Model<Input = Tensor, Output = Tensor>> MultiTaskLearningTrainer<M> {
363    /// Create a new multi-task learning trainer
364    pub fn new(base_model: M, config: MTLConfig) -> Result<Self> {
365        let mut task_heads = HashMap::new();
366        let mut task_weights = HashMap::new();
367
368        // Initialize task heads
369        for task_config in &config.tasks {
370            let task_head = TaskHead::new(&task_config.task_type)?;
371            task_heads.insert(task_config.name.clone(), task_head);
372            task_weights.insert(task_config.name.clone(), task_config.weight);
373        }
374
375        let scheduler_state = TaskSchedulerState::new(&config.task_scheduling);
376
377        Ok(Self {
378            base_model,
379            task_heads,
380            config,
381            task_weights,
382            task_performance: HashMap::new(),
383            step_counter: 0,
384            scheduler_state,
385            gradient_stats: HashMap::new(),
386        })
387    }
388
389    /// Train on multiple tasks for one step
390    pub fn train_multi_task_step(
391        &mut self,
392        task_data: &HashMap<String, TaskBatch>,
393    ) -> Result<MultiTaskOutput> {
394        let mut task_losses = HashMap::new();
395        let mut task_accuracies = HashMap::new();
396        let mut total_loss = Tensor::zeros(&[1])?;
397
398        // Determine which tasks to train on this step
399        let active_tasks = self.get_active_tasks(task_data)?;
400
401        for task_name in &active_tasks {
402            if let Some(batch) = task_data.get(task_name) {
403                // Forward pass through shared layers
404                let shared_features = self.base_model.forward(batch.inputs.clone())?;
405
406                // Task-specific forward pass
407                let task_head = self
408                    .task_heads
409                    .get(task_name)
410                    .ok_or_else(|| anyhow::anyhow!("Task head not found: {}", task_name))?;
411
412                let task_outputs = task_head.forward(&shared_features)?;
413                let task_loss = self.compute_task_loss(task_name, &task_outputs, &batch.targets)?;
414                let task_accuracy =
415                    self.compute_task_accuracy(task_name, &task_outputs, &batch.targets)?;
416
417                task_losses.insert(task_name.clone(), task_loss.clone());
418                task_accuracies.insert(task_name.clone(), task_accuracy);
419
420                // Update task performance history
421                self.task_performance.entry(task_name.clone()).or_default().push(task_accuracy);
422            }
423        }
424
425        // Balance losses across tasks
426        let balanced_losses = self.balance_losses(&task_losses)?;
427
428        // Compute total loss
429        for (task_name, loss) in &balanced_losses {
430            let weight = self.task_weights.get(task_name).copied().unwrap_or(1.0);
431            total_loss = total_loss.add(&loss.scalar_mul(weight)?)?;
432        }
433
434        // Update task weights if using dynamic balancing
435        self.update_task_weights(&task_losses)?;
436
437        // Update auxiliary tasks if enabled
438        if self.config.use_auxiliary_tasks {
439            let aux_loss = self.compute_auxiliary_losses(task_data)?;
440            total_loss = total_loss.add(&aux_loss)?;
441        }
442
443        self.step_counter += 1;
444
445        Ok(MultiTaskOutput {
446            total_loss,
447            task_losses: task_losses
448                .into_iter()
449                .map(|(k, v)| (k, v.to_scalar().unwrap_or(0.0)))
450                .collect(),
451            task_accuracies,
452            active_tasks,
453            task_weights: self.task_weights.clone(),
454        })
455    }
456
457    /// Get active tasks for current training step
458    fn get_active_tasks(&mut self, task_data: &HashMap<String, TaskBatch>) -> Result<Vec<String>> {
459        match &self.config.task_scheduling {
460            TaskSchedulingStrategy::RoundRobin => {
461                let task_names: Vec<String> = task_data.keys().cloned().collect();
462                if task_names.is_empty() {
463                    return Ok(Vec::new());
464                }
465                let current_task = &task_names[self.step_counter % task_names.len()];
466                Ok(vec![current_task.clone()])
467            },
468            TaskSchedulingStrategy::WeightedSampling => {
469                // Sample tasks based on their weights/priorities
470                let mut weighted_tasks = Vec::new();
471                for task_config in &self.config.tasks {
472                    if task_data.contains_key(&task_config.name) {
473                        let weight = match task_config.priority {
474                            TaskPriority::Low => 0.5,
475                            TaskPriority::Normal => 1.0,
476                            TaskPriority::High => 2.0,
477                            TaskPriority::Critical => 3.0,
478                        };
479                        for _ in 0..(weight * 10.0) as usize {
480                            weighted_tasks.push(task_config.name.clone());
481                        }
482                    }
483                }
484                if weighted_tasks.is_empty() {
485                    return Ok(Vec::new());
486                }
487                let selected_task = &weighted_tasks[self.step_counter % weighted_tasks.len()];
488                Ok(vec![selected_task.clone()])
489            },
490            TaskSchedulingStrategy::Random => {
491                let task_names: Vec<String> = task_data.keys().cloned().collect();
492                if task_names.is_empty() {
493                    return Ok(Vec::new());
494                }
495                let random_idx = fastrand::usize(..task_names.len());
496                Ok(vec![task_names[random_idx].clone()])
497            },
498            _ => {
499                // For other strategies, train on all available tasks
500                Ok(task_data.keys().cloned().collect())
501            },
502        }
503    }
504
505    /// Balance losses across tasks
506    fn balance_losses(
507        &self,
508        task_losses: &HashMap<String, Tensor>,
509    ) -> Result<HashMap<String, Tensor>> {
510        match &self.config.loss_balancing {
511            LossBalancingStrategy::EqualWeighting => Ok(task_losses.clone()),
512            LossBalancingStrategy::ManualWeighting { weights } => {
513                let mut balanced = HashMap::new();
514                for (i, (task_name, loss)) in task_losses.iter().enumerate() {
515                    let weight = weights.get(i).copied().unwrap_or(1.0);
516                    balanced.insert(task_name.clone(), loss.scalar_mul(weight)?);
517                }
518                Ok(balanced)
519            },
520            LossBalancingStrategy::UncertaintyWeighting => {
521                // Implement uncertainty-based weighting
522                // This would typically involve learning task-specific uncertainty parameters
523                Ok(task_losses.clone()) // Simplified for now
524            },
525            LossBalancingStrategy::DynamicWeightAverage => {
526                // Use dynamic weight average algorithm
527                self.apply_dynamic_weight_average(task_losses)
528            },
529            LossBalancingStrategy::GradNorm { alpha } => {
530                // Apply GradNorm algorithm
531                self.apply_gradnorm(task_losses, *alpha)
532            },
533            _ => Ok(task_losses.clone()),
534        }
535    }
536
537    /// Apply dynamic weight average algorithm
538    fn apply_dynamic_weight_average(
539        &self,
540        task_losses: &HashMap<String, Tensor>,
541    ) -> Result<HashMap<String, Tensor>> {
542        // DWA uses relative descent rates to weight tasks
543        let mut balanced = HashMap::new();
544
545        if self.step_counter < 2 {
546            return Ok(task_losses.clone());
547        }
548
549        let temperature = 2.0; // DWA temperature parameter
550
551        for (task_name, loss) in task_losses {
552            // Get previous loss for this task
553            let prev_loss = self.get_previous_task_loss(task_name);
554            let current_loss = loss.to_scalar().unwrap_or(0.0);
555
556            let weight = if prev_loss > 0.0 {
557                let relative_decrease = current_loss / prev_loss;
558                (relative_decrease / temperature).exp()
559            } else {
560                1.0
561            };
562
563            balanced.insert(task_name.clone(), loss.clone().mul_scalar(weight)?);
564        }
565
566        Ok(balanced)
567    }
568
569    /// Apply GradNorm algorithm
570    fn apply_gradnorm(
571        &self,
572        task_losses: &HashMap<String, Tensor>,
573        _alpha: f32,
574    ) -> Result<HashMap<String, Tensor>> {
575        // GradNorm balances gradient magnitudes across tasks
576        // This is a simplified implementation
577        Ok(task_losses.clone())
578    }
579
580    /// Update task weights based on performance
581    fn update_task_weights(&mut self, task_losses: &HashMap<String, Tensor>) -> Result<()> {
582        match &self.config.loss_balancing {
583            LossBalancingStrategy::DynamicWeightAverage => {
584                // Update weights based on loss trends
585                for (task_name, loss) in task_losses {
586                    let current_loss = loss.to_scalar().unwrap_or(0.0);
587                    // Update internal weight tracking
588                    // This would be more sophisticated in practice
589                    if let Some(weight) = self.task_weights.get_mut(task_name) {
590                        *weight = (*weight * 0.9 + current_loss * 0.1).clamp(0.1, 10.0);
591                    }
592                }
593            },
594            _ => {
595                // Other strategies don't update weights dynamically
596            },
597        }
598        Ok(())
599    }
600
601    /// Get previous task loss for DWA
602    fn get_previous_task_loss(&self, _task_name: &str) -> f32 {
603        // This would get the loss from the previous step
604        // Simplified implementation
605        1.0
606    }
607
608    /// Compute auxiliary task losses
609    fn compute_auxiliary_losses(&self, task_data: &HashMap<String, TaskBatch>) -> Result<Tensor> {
610        let mut aux_loss: Tensor = Tensor::zeros(&[1])?;
611
612        for aux_config in &self.config.auxiliary_tasks {
613            if self.should_train_auxiliary_task(aux_config) {
614                if let Some(aux_data) = task_data.get(&aux_config.name) {
615                    let aux_task_loss: Tensor =
616                        self.compute_auxiliary_task_loss(aux_config, aux_data)?;
617                    let weighted_loss: Tensor = aux_task_loss.mul_scalar(aux_config.weight)?;
618                    aux_loss = aux_loss.add(&weighted_loss)?;
619                }
620            }
621        }
622
623        Ok(aux_loss)
624    }
625
626    /// Check if auxiliary task should be trained this step
627    fn should_train_auxiliary_task(&self, aux_config: &AuxiliaryTaskConfig) -> bool {
628        match &aux_config.frequency {
629            AuxiliaryTaskFrequency::EveryNSteps(n) => self.step_counter % n == 0,
630            AuxiliaryTaskFrequency::WithProbability(p) => fastrand::f32() < *p,
631            AuxiliaryTaskFrequency::Continuous => true,
632            AuxiliaryTaskFrequency::EpochRange { start, end } => {
633                let current_epoch = self.step_counter / 1000; // Simplified epoch calculation
634                current_epoch >= *start && current_epoch <= *end
635            },
636        }
637    }
638
639    /// Compute auxiliary task loss
640    fn compute_auxiliary_task_loss(
641        &self,
642        aux_config: &AuxiliaryTaskConfig,
643        data: &TaskBatch,
644    ) -> Result<Tensor> {
645        // Compute loss for auxiliary task
646        let shared_features: Tensor = self.base_model.forward(data.inputs.clone())?;
647
648        match &aux_config.auxiliary_type {
649            AuxiliaryType::LanguageModeling => {
650                // Compute language modeling loss
651                self.compute_lm_loss(&shared_features, &data.targets)
652            },
653            AuxiliaryType::MaskedLanguageModeling => {
654                // Compute MLM loss
655                self.compute_mlm_loss(&shared_features, &data.targets)
656            },
657            _ => {
658                // Other auxiliary tasks
659                Ok(Tensor::zeros(&[1])?)
660            },
661        }
662    }
663
664    /// Compute language modeling loss
665    fn compute_lm_loss(&self, _features: &Tensor, _targets: &Tensor) -> Result<Tensor> {
666        // Simplified LM loss computation
667        Tensor::zeros(&[1])
668    }
669
670    /// Compute masked language modeling loss
671    fn compute_mlm_loss(&self, _features: &Tensor, _targets: &Tensor) -> Result<Tensor> {
672        // Simplified MLM loss computation
673        Tensor::zeros(&[1])
674    }
675
676    /// Compute task-specific loss
677    fn compute_task_loss(
678        &self,
679        task_name: &str,
680        outputs: &Tensor,
681        targets: &Tensor,
682    ) -> Result<Tensor> {
683        let task_config = self
684            .config
685            .tasks
686            .iter()
687            .find(|t| t.name == task_name)
688            .ok_or_else(|| invalid_input(format!("Task not found: {}", task_name)))?;
689
690        match &task_config.task_type {
691            TaskType::Classification { .. } => {
692                // Cross-entropy loss
693                let log_probs = outputs.softmax(-1)?;
694                let nll_loss = targets.mul(&log_probs)?.sum(Some(vec![1]), false)?;
695                Ok(nll_loss.mean()?.mul_scalar(-1.0)?)
696            },
697            TaskType::Regression { loss_type, .. } => {
698                match loss_type {
699                    RegressionLossType::MSE => {
700                        let diff = outputs.sub(targets)?;
701                        Ok(diff.mul(&diff)?.mean()?)
702                    },
703                    RegressionLossType::MAE => {
704                        let diff = outputs.sub(targets)?;
705                        Ok(diff.abs()?.mean()?)
706                    },
707                    RegressionLossType::Huber { delta } => {
708                        let diff = outputs.sub(targets)?;
709                        let abs_diff = diff.abs()?;
710                        let small_loss = diff.mul(&diff)?.mul_scalar(0.5)?;
711                        let _large_loss =
712                            abs_diff.mul_scalar(*delta)?.sub_scalar(*delta * *delta * 0.5)?;
713                        // Simplified Huber loss approximation
714                        Ok(small_loss.mean()?)
715                    },
716                    _ => {
717                        // Other regression losses
718                        let diff = outputs.sub(targets)?;
719                        Ok(diff.mul(&diff)?.mean()?)
720                    },
721                }
722            },
723            _ => {
724                // Other task types
725                Ok(Tensor::zeros(&[1])?)
726            },
727        }
728    }
729
730    /// Compute task-specific accuracy
731    fn compute_task_accuracy(
732        &self,
733        task_name: &str,
734        outputs: &Tensor,
735        targets: &Tensor,
736    ) -> Result<f32> {
737        let task_config = self
738            .config
739            .tasks
740            .iter()
741            .find(|t| t.name == task_name)
742            .ok_or_else(|| invalid_input(format!("Task not found: {}", task_name)))?;
743
744        match &task_config.task_type {
745            TaskType::Classification { .. } => {
746                let predicted = outputs.argmax(-1)?;
747                let target_class = targets.argmax(-1)?;
748                let correct = (predicted.to_scalar().unwrap_or(-1.0)
749                    == target_class.to_scalar().unwrap_or(-2.0))
750                    as i32 as f32;
751                Ok(correct)
752            },
753            TaskType::Regression { .. } => {
754                // For regression, compute R² or similar metric
755                let diff = outputs.sub(targets)?;
756                let mse = diff.mul(&diff)?.mean()?;
757                let mean_targets = targets.mean()?;
758                let diff_from_mean = targets.sub(&mean_targets)?;
759                let variance = diff_from_mean.pow_scalar(2.0)?.mean()?;
760                let r_squared =
761                    1.0 - mse.to_scalar().unwrap_or(1.0) / variance.to_scalar().unwrap_or(1.0);
762                Ok(r_squared.max(0.0))
763            },
764            _ => Ok(0.0),
765        }
766    }
767
768    /// Evaluate all tasks
769    pub fn evaluate_all_tasks(
770        &self,
771        test_data: &HashMap<String, TaskBatch>,
772    ) -> Result<MultiTaskEvaluation> {
773        let mut task_evaluations = HashMap::new();
774
775        for (task_name, batch) in test_data {
776            if let Some(task_head) = self.task_heads.get(task_name) {
777                let shared_features = self.base_model.forward(batch.inputs.clone())?;
778                let task_outputs = task_head.forward(&shared_features)?;
779                let loss = self.compute_task_loss(task_name, &task_outputs, &batch.targets)?;
780                let accuracy =
781                    self.compute_task_accuracy(task_name, &task_outputs, &batch.targets)?;
782
783                task_evaluations.insert(
784                    task_name.clone(),
785                    TaskEvaluation {
786                        task_name: task_name.clone(),
787                        loss: loss.to_scalar().unwrap_or(0.0),
788                        accuracy,
789                        num_examples: batch.inputs.shape()[0],
790                    },
791                );
792            }
793        }
794
795        let overall_accuracy = if !task_evaluations.is_empty() {
796            task_evaluations.values().map(|e| e.accuracy).sum::<f32>()
797                / task_evaluations.len() as f32
798        } else {
799            0.0
800        };
801
802        Ok(MultiTaskEvaluation {
803            task_evaluations,
804            overall_accuracy,
805            step: self.step_counter,
806        })
807    }
808
809    /// Get multi-task learning statistics
810    pub fn get_mtl_stats(&self) -> MTLStats {
811        MTLStats {
812            num_tasks: self.config.tasks.len(),
813            task_weights: self.task_weights.clone(),
814            step_counter: self.step_counter,
815            architecture: self.config.architecture.clone(),
816            loss_balancing: self.config.loss_balancing.clone(),
817        }
818    }
819}
820
/// Task-specific neural network head.
///
/// A stack of [`Linear`] layers applied, in order, to the shared features produced by the
/// trainer's base model.
pub struct TaskHead {
    /// Layers applied sequentially by `forward`
    layers: Vec<Linear>,
    // A copy of the task type is kept, but no method in this module reads it yet, hence
    // the dead_code allowance.
    #[allow(dead_code)]
    task_type: TaskType,
}
827
828impl TaskHead {
829    pub fn new(task_type: &TaskType) -> Result<Self> {
830        let mut layers = Vec::new();
831
832        match task_type {
833            TaskType::Classification { num_classes, .. } => {
834                // Simple classification head
835                layers.push(Linear::new(768, *num_classes, true)); // Assuming 768 hidden size
836            },
837            TaskType::Regression { output_dim, .. } => {
838                layers.push(Linear::new(768, *output_dim, true));
839            },
840            _ => {
841                // Default head
842                layers.push(Linear::new(768, 768, true));
843            },
844        }
845
846        Ok(Self {
847            layers,
848            task_type: task_type.clone(),
849        })
850    }
851
852    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
853        let mut output = input.clone();
854        for layer in &self.layers {
855            output = layer.forward(output)?;
856        }
857        Ok(output)
858    }
859}
860
/// Training data batch for a specific task.
///
/// The trainer looks batches up by their key in the per-task map it is given, not by the
/// `task_name` field.
#[derive(Debug, Clone)]
pub struct TaskBatch {
    /// Inputs fed to the shared base model
    pub inputs: Tensor,
    /// Targets compared against the task head's outputs
    pub targets: Tensor,
    /// Name of the task this batch belongs to (informational)
    pub task_name: String,
}
868
/// Task scheduler state.
///
/// NOTE(review): created by the trainer but not read or updated by the trainer methods in
/// this module.
pub struct TaskSchedulerState {
    /// Index of the current task
    pub current_task_index: usize,
    /// Per-task counters, keyed by task name
    pub task_counters: HashMap<String, usize>,
}
874
875impl TaskSchedulerState {
876    pub fn new(_strategy: &TaskSchedulingStrategy) -> Self {
877        Self {
878            current_task_index: 0,
879            task_counters: HashMap::new(),
880        }
881    }
882}
883
/// Gradient statistics for task balancing.
///
/// NOTE(review): nothing in this section reads or writes these fields; the
/// descriptions below follow the field names and should be confirmed against
/// the code that fills them in.
#[derive(Debug, Clone)]
pub struct GradientStats {
    /// Presumably the norm of the task's gradient.
    pub gradient_norm: f32,
    /// Presumably the variance of the task's gradient.
    pub gradient_variance: f32,
    /// Presumably how many updates have been folded into these statistics.
    pub update_count: usize,
}
891
/// Output from a multi-task training step.
///
/// The per-task maps are keyed by `String`, presumably task names — confirm
/// against the training step that builds this value.
#[derive(Debug, Clone)]
pub struct MultiTaskOutput {
    /// Combined loss for the step, kept as a tensor (presumably the weighted
    /// combination of task losses — confirm).
    pub total_loss: Tensor,
    /// Per-task loss values.
    pub task_losses: HashMap<String, f32>,
    /// Per-task accuracy values.
    pub task_accuracies: HashMap<String, f32>,
    /// Tasks that took part in this step.
    pub active_tasks: Vec<String>,
    /// Task weights in effect for this step.
    pub task_weights: HashMap<String, f32>,
}
901
/// Evaluation results for a single task.
#[derive(Debug, Clone)]
pub struct TaskEvaluation {
    /// Name of the evaluated task.
    pub task_name: String,
    /// Evaluation loss for the task.
    pub loss: f32,
    /// Evaluation accuracy for the task.
    pub accuracy: f32,
    /// Number of examples the results were computed over (presumably the size
    /// of this task's evaluation set — confirm against the evaluation method).
    pub num_examples: usize,
}
910
/// Multi-task evaluation results.
#[derive(Debug, Clone)]
pub struct MultiTaskEvaluation {
    /// Per-task results; presumably keyed by task name — confirm against the
    /// evaluation method that fills this map.
    pub task_evaluations: HashMap<String, TaskEvaluation>,
    /// Accuracy aggregated across tasks by the evaluation method that builds
    /// this value.
    pub overall_accuracy: f32,
    /// The trainer's step counter at the time the evaluation was produced.
    pub step: usize,
}
918
/// Multi-task learning statistics.
///
/// A snapshot produced by `get_mtl_stats`, which clones the trainer's state
/// into it, so it does not borrow from the trainer.
#[derive(Debug, Clone)]
pub struct MTLStats {
    /// Number of configured tasks (`config.tasks.len()`).
    pub num_tasks: usize,
    /// Copy of the trainer's current task weights.
    pub task_weights: HashMap<String, f32>,
    /// The trainer's step counter at snapshot time.
    pub step_counter: usize,
    /// Configured multi-task architecture.
    pub architecture: MTLArchitecture,
    /// Configured loss-balancing strategy.
    pub loss_balancing: LossBalancingStrategy,
}
928
929/// Utilities for multi-task learning
930pub mod utils {
931    use super::*;
932
933    /// Create a simple hard parameter sharing configuration
934    pub fn hard_parameter_sharing_config(
935        tasks: Vec<TaskConfig>,
936        shared_layers: usize,
937        task_specific_layers: usize,
938    ) -> MTLConfig {
939        MTLConfig {
940            architecture: MTLArchitecture::HardParameterSharing {
941                shared_layers,
942                task_specific_layers,
943            },
944            tasks,
945            ..Default::default()
946        }
947    }
948
949    /// Create a soft parameter sharing configuration
950    pub fn soft_parameter_sharing_config(
951        tasks: Vec<TaskConfig>,
952        regularization_weight: f32,
953    ) -> MTLConfig {
954        MTLConfig {
955            architecture: MTLArchitecture::SoftParameterSharing {
956                regularization_weight,
957                regularization_type: RegularizationType::L2Regularization,
958            },
959            tasks,
960            ..Default::default()
961        }
962    }
963
964    /// Create a multi-gate mixture of experts configuration
965    pub fn mmoe_config(tasks: Vec<TaskConfig>, num_experts: usize, expert_dim: usize) -> MTLConfig {
966        MTLConfig {
967            architecture: MTLArchitecture::MultiGateMixtureOfExperts {
968                num_experts,
969                expert_dim,
970                num_gates: tasks.len(),
971            },
972            tasks,
973            ..Default::default()
974        }
975    }
976
977    /// Create task configuration for classification
978    pub fn classification_task(name: &str, num_classes: usize) -> TaskConfig {
979        TaskConfig::new(
980            name,
981            TaskType::Classification {
982                num_classes,
983                use_class_weights: false,
984            },
985        )
986    }
987
988    /// Create task configuration for regression
989    pub fn regression_task(name: &str, output_dim: usize) -> TaskConfig {
990        TaskConfig::new(
991            name,
992            TaskType::Regression {
993                output_dim,
994                loss_type: RegressionLossType::MSE,
995            },
996        )
997    }
998
999    /// Create auxiliary task configuration for MLM
1000    pub fn mlm_auxiliary_task(weight: f32) -> AuxiliaryTaskConfig {
1001        AuxiliaryTaskConfig {
1002            name: "mlm".to_string(),
1003            auxiliary_type: AuxiliaryType::MaskedLanguageModeling,
1004            weight,
1005            frequency: AuxiliaryTaskFrequency::EveryNSteps(10),
1006        }
1007    }
1008
1009    /// Compute task similarity matrix
1010    pub fn compute_task_similarity(
1011        task_performances: &HashMap<String, Vec<f32>>,
1012    ) -> HashMap<(String, String), f32> {
1013        let mut similarities = HashMap::new();
1014        let tasks: Vec<String> = task_performances.keys().cloned().collect();
1015
1016        for i in 0..tasks.len() {
1017            for j in i + 1..tasks.len() {
1018                let task1 = &tasks[i];
1019                let task2 = &tasks[j];
1020
1021                if let (Some(perf1), Some(perf2)) =
1022                    (task_performances.get(task1), task_performances.get(task2))
1023                {
1024                    let similarity = compute_correlation(perf1, perf2);
1025                    similarities.insert((task1.clone(), task2.clone()), similarity);
1026                    similarities.insert((task2.clone(), task1.clone()), similarity);
1027                }
1028            }
1029        }
1030
1031        similarities
1032    }
1033
1034    /// Compute correlation between two performance sequences
1035    pub fn compute_correlation(seq1: &[f32], seq2: &[f32]) -> f32 {
1036        if seq1.len() != seq2.len() || seq1.is_empty() {
1037            return 0.0;
1038        }
1039
1040        let n = seq1.len() as f32;
1041        let mean1 = seq1.iter().sum::<f32>() / n;
1042        let mean2 = seq2.iter().sum::<f32>() / n;
1043
1044        let mut numerator = 0.0;
1045        let mut denom1 = 0.0;
1046        let mut denom2 = 0.0;
1047
1048        for i in 0..seq1.len() {
1049            let diff1 = seq1[i] - mean1;
1050            let diff2 = seq2[i] - mean2;
1051            numerator += diff1 * diff2;
1052            denom1 += diff1 * diff1;
1053            denom2 += diff2 * diff2;
1054        }
1055
1056        if denom1 * denom2 > 0.0 {
1057            numerator / (denom1 * denom2).sqrt()
1058        } else {
1059            0.0
1060        }
1061    }
1062
1063    /// Analyze multi-task learning effectiveness
1064    pub fn analyze_mtl_effectiveness(
1065        single_task_performances: &HashMap<String, f32>,
1066        multi_task_performances: &HashMap<String, f32>,
1067    ) -> MTLAnalysis {
1068        let mut positive_transfer_tasks = Vec::new();
1069        let mut negative_transfer_tasks = Vec::new();
1070        let mut total_improvement = 0.0;
1071        let mut num_tasks = 0;
1072
1073        for (task_name, &mtl_perf) in multi_task_performances {
1074            if let Some(&single_perf) = single_task_performances.get(task_name) {
1075                let improvement = mtl_perf - single_perf;
1076                total_improvement += improvement;
1077                num_tasks += 1;
1078
1079                if improvement > 0.0 {
1080                    positive_transfer_tasks.push(task_name.clone());
1081                } else if improvement < 0.0 {
1082                    negative_transfer_tasks.push(task_name.clone());
1083                }
1084            }
1085        }
1086
1087        let average_improvement =
1088            if num_tasks > 0 { total_improvement / num_tasks as f32 } else { 0.0 };
1089
1090        MTLAnalysis {
1091            average_improvement,
1092            positive_transfer_tasks,
1093            negative_transfer_tasks,
1094            num_tasks,
1095        }
1096    }
1097}
1098
/// Analysis of multi-task learning effectiveness.
///
/// Produced by [`utils::analyze_mtl_effectiveness`] from tasks present in both
/// the single-task and multi-task score maps.
#[derive(Debug, Clone)]
pub struct MTLAnalysis {
    /// Mean of `multi_task - single_task` over the compared tasks (`0.0` if none).
    pub average_improvement: f32,
    /// Tasks whose multi-task score is strictly higher than the single-task score.
    pub positive_transfer_tasks: Vec<String>,
    /// Tasks whose multi-task score is strictly lower than the single-task score.
    pub negative_transfer_tasks: Vec<String>,
    /// Number of tasks present in both inputs, i.e. the number compared.
    pub num_tasks: usize,
}
1107
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `task name -> score` map from borrowed pairs.
    fn score_map(entries: &[(&str, f32)]) -> HashMap<String, f32> {
        entries
            .iter()
            .map(|&(name, score)| (name.to_string(), score))
            .collect()
    }

    #[test]
    fn test_mtl_config_default() {
        let cfg = MTLConfig::default();
        assert!(cfg.tasks.is_empty());
        assert!(!cfg.use_task_embeddings);
        assert!(!cfg.use_auxiliary_tasks);

        match cfg.architecture {
            MTLArchitecture::HardParameterSharing {
                shared_layers,
                task_specific_layers,
            } => {
                assert_eq!(shared_layers, 8);
                assert_eq!(task_specific_layers, 2);
            },
            _ => panic!("Expected HardParameterSharing architecture"),
        }
    }

    #[test]
    fn test_task_config() {
        let base = TaskConfig::new(
            "test",
            TaskType::Classification {
                num_classes: 10,
                use_class_weights: false,
            },
        );

        assert_eq!(base.name, "test");
        assert_eq!(base.weight, 1.0);
        assert!(!base.is_main_task);

        let reweighted = base.with_weight(2.0);
        assert_eq!(reweighted.weight, 2.0);
    }

    #[test]
    fn test_classification_task_util() {
        let sentiment = utils::classification_task("sentiment", 3);
        assert_eq!(sentiment.name, "sentiment");

        match sentiment.task_type {
            TaskType::Classification { num_classes, .. } => assert_eq!(num_classes, 3),
            _ => panic!("Expected Classification task type"),
        }
    }

    #[test]
    fn test_regression_task_util() {
        let score = utils::regression_task("score", 1);
        assert_eq!(score.name, "score");

        match score.task_type {
            TaskType::Regression { output_dim, .. } => assert_eq!(output_dim, 1),
            _ => panic!("Expected Regression task type"),
        }
    }

    #[test]
    fn test_hard_parameter_sharing_config() {
        let task_list = vec![
            utils::classification_task("task1", 5),
            utils::regression_task("task2", 1),
        ];

        let cfg = utils::hard_parameter_sharing_config(task_list, 6, 2);
        assert_eq!(cfg.tasks.len(), 2);

        match cfg.architecture {
            MTLArchitecture::HardParameterSharing {
                shared_layers,
                task_specific_layers,
            } => {
                assert_eq!(shared_layers, 6);
                assert_eq!(task_specific_layers, 2);
            },
            _ => panic!("Expected HardParameterSharing architecture"),
        }
    }

    #[test]
    fn test_soft_parameter_sharing_config() {
        let task_list = vec![utils::classification_task("task1", 5)];
        let cfg = utils::soft_parameter_sharing_config(task_list, 0.01);

        match cfg.architecture {
            MTLArchitecture::SoftParameterSharing {
                regularization_weight,
                ..
            } => assert_eq!(regularization_weight, 0.01),
            _ => panic!("Expected SoftParameterSharing architecture"),
        }
    }

    #[test]
    fn test_mmoe_config() {
        let task_list = vec![
            utils::classification_task("task1", 5),
            utils::classification_task("task2", 3),
        ];

        let cfg = utils::mmoe_config(task_list, 4, 128);

        match cfg.architecture {
            MTLArchitecture::MultiGateMixtureOfExperts {
                num_experts,
                expert_dim,
                num_gates,
            } => {
                assert_eq!(num_experts, 4);
                assert_eq!(expert_dim, 128);
                assert_eq!(num_gates, 2);
            },
            _ => panic!("Expected MultiGateMixtureOfExperts architecture"),
        }
    }

    #[test]
    fn test_mlm_auxiliary_task() {
        let aux = utils::mlm_auxiliary_task(0.1);
        assert_eq!(aux.name, "mlm");
        assert_eq!(aux.weight, 0.1);
        assert!(
            matches!(aux.auxiliary_type, AuxiliaryType::MaskedLanguageModeling),
            "Expected MaskedLanguageModeling auxiliary type"
        );
    }

    #[test]
    fn test_compute_correlation() {
        let rising: [f32; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
        // Perfectly positively correlated with `rising`.
        let doubled: [f32; 5] = [2.0, 4.0, 6.0, 8.0, 10.0];
        // Perfectly negatively correlated with `rising`.
        let falling: [f32; 5] = [5.0, 4.0, 3.0, 2.0, 1.0];

        let positive = utils::compute_correlation(&rising, &doubled);
        assert!((positive - 1.0).abs() < 1e-6);

        let negative = utils::compute_correlation(&rising, &falling);
        assert!((negative + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_mtl_analysis() {
        let single_task = score_map(&[("task1", 0.8), ("task2", 0.7), ("task3", 0.6)]);
        let multi_task = score_map(&[
            ("task1", 0.85), // positive transfer
            ("task2", 0.65), // negative transfer
            ("task3", 0.65), // positive transfer
        ]);

        let report = utils::analyze_mtl_effectiveness(&single_task, &multi_task);
        assert_eq!(report.num_tasks, 3);
        assert_eq!(report.positive_transfer_tasks.len(), 2);
        assert_eq!(report.negative_transfer_tasks.len(), 1);
        assert!(report.positive_transfer_tasks.iter().any(|t| t == "task1"));
        assert!(report.negative_transfer_tasks.iter().any(|t| t == "task2"));
    }
}