Skip to main content

tensorlogic_train/
multitask.rs

1//! Multi-task learning utilities for training with multiple objectives.
2//!
3//! This module provides utilities for multi-task learning, including:
4//! - Task weighting strategies
5//! - Multi-task loss composition
6//! - Gradient balancing techniques
7//! - Task-specific metrics tracking
8
9use crate::{Loss, TrainError, TrainResult};
10use scirs2_core::ndarray::{s, Array, ArrayView, Ix2};
11use std::collections::HashMap;
12
/// Strategy for weighting multiple tasks.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TaskWeightingStrategy {
    /// Fixed weights for each task.
    Fixed,
    /// Dynamic Task Prioritization (DTP) - weights based on task difficulty.
    /// Tasks with higher current loss receive proportionally higher weight.
    DynamicTaskPrioritization,
    /// GradNorm - balances gradient magnitudes across tasks.
    /// `alpha` scales how strongly each task's weight influences its target
    /// training rate (see `MultiTaskLoss::update_weights`).
    GradNorm { alpha: f64 },
    /// Uncertainty weighting - learns task weights from homoscedastic uncertainty.
    /// NOTE(review): currently a no-op placeholder in `update_weights`.
    UncertaintyWeighting,
}
25
/// Multi-task loss that combines multiple losses with configurable weighting.
pub struct MultiTaskLoss {
    /// Individual task losses.
    pub task_losses: Vec<Box<dyn Loss>>,
    /// Task weights (automatically managed based on strategy).
    pub task_weights: Vec<f64>,
    /// Weighting strategy.
    pub strategy: TaskWeightingStrategy,
    /// Learning rate for weight updates (used in some strategies).
    pub weight_lr: f64,
    /// Initial loss values for normalization.
    /// Captured from the first `compute_multi_task` call and used by the
    /// GradNorm strategy to measure each task's relative training rate.
    initial_losses: Option<Vec<f64>>,
}
39
40impl MultiTaskLoss {
41    /// Create a new multi-task loss with fixed weights.
42    ///
43    /// # Arguments
44    /// * `task_losses` - Individual loss functions for each task
45    /// * `task_weights` - Fixed weights for each task (should sum to 1.0)
46    pub fn new_fixed(task_losses: Vec<Box<dyn Loss>>, task_weights: Vec<f64>) -> TrainResult<Self> {
47        if task_losses.len() != task_weights.len() {
48            return Err(TrainError::ConfigError(
49                "Number of losses must match number of weights".to_string(),
50            ));
51        }
52
53        if task_losses.is_empty() {
54            return Err(TrainError::ConfigError(
55                "Must have at least one task".to_string(),
56            ));
57        }
58
59        Ok(Self {
60            task_losses,
61            task_weights,
62            strategy: TaskWeightingStrategy::Fixed,
63            weight_lr: 0.0,
64            initial_losses: None,
65        })
66    }
67
68    /// Create a new multi-task loss with dynamic weighting.
69    ///
70    /// # Arguments
71    /// * `task_losses` - Individual loss functions for each task
72    /// * `strategy` - Weighting strategy to use
73    /// * `weight_lr` - Learning rate for weight updates
74    pub fn new_dynamic(
75        task_losses: Vec<Box<dyn Loss>>,
76        strategy: TaskWeightingStrategy,
77        weight_lr: f64,
78    ) -> TrainResult<Self> {
79        if task_losses.is_empty() {
80            return Err(TrainError::ConfigError(
81                "Must have at least one task".to_string(),
82            ));
83        }
84
85        let n_tasks = task_losses.len();
86        let task_weights = vec![1.0 / n_tasks as f64; n_tasks];
87
88        Ok(Self {
89            task_losses,
90            task_weights,
91            strategy,
92            weight_lr,
93            initial_losses: None,
94        })
95    }
96
97    /// Compute multi-task loss.
98    ///
99    /// # Arguments
100    /// * `predictions` - Predictions for all tasks (concatenated)
101    /// * `targets` - Targets for all tasks (concatenated)
102    /// * `task_splits` - Column indices where each task starts
103    ///
104    /// # Returns
105    /// Weighted sum of task losses
106    pub fn compute_multi_task(
107        &mut self,
108        predictions: &ArrayView<f64, Ix2>,
109        targets: &ArrayView<f64, Ix2>,
110        task_splits: &[usize],
111    ) -> TrainResult<f64> {
112        if task_splits.len() != self.task_losses.len() + 1 {
113            return Err(TrainError::LossError(format!(
114                "task_splits must have {} elements (n_tasks + 1)",
115                self.task_losses.len() + 1
116            )));
117        }
118
119        let mut task_losses_values = Vec::new();
120
121        // Compute individual task losses
122        for i in 0..self.task_losses.len() {
123            let start = task_splits[i];
124            let end = task_splits[i + 1];
125
126            let task_pred = predictions.slice(s![.., start..end]);
127            let task_target = targets.slice(s![.., start..end]);
128
129            let loss_value = self.task_losses[i].compute(&task_pred, &task_target)?;
130            task_losses_values.push(loss_value);
131        }
132
133        // Initialize on first call
134        if self.initial_losses.is_none() {
135            self.initial_losses = Some(task_losses_values.clone());
136        }
137
138        // Update task weights based on strategy
139        self.update_weights(&task_losses_values)?;
140
141        // Compute weighted sum
142        let total_loss = task_losses_values
143            .iter()
144            .zip(self.task_weights.iter())
145            .map(|(loss, weight)| loss * weight)
146            .sum();
147
148        Ok(total_loss)
149    }
150
151    /// Update task weights based on the selected strategy.
152    fn update_weights(&mut self, current_losses: &[f64]) -> TrainResult<()> {
153        match self.strategy {
154            TaskWeightingStrategy::Fixed => {
155                // Weights don't change
156                Ok(())
157            }
158            TaskWeightingStrategy::DynamicTaskPrioritization => {
159                // Weight tasks inversely to their performance
160                // Tasks with higher loss get higher weight
161                let sum: f64 = current_losses.iter().sum();
162                if sum > 1e-8 {
163                    for (i, &loss) in current_losses.iter().enumerate() {
164                        self.task_weights[i] = loss / sum;
165                    }
166                }
167                Ok(())
168            }
169            TaskWeightingStrategy::GradNorm { alpha } => {
170                // GradNorm: balance gradient magnitudes
171                // Simplified version - in practice, needs gradient information
172                if let Some(ref initial) = self.initial_losses {
173                    let mut relative_rates = Vec::new();
174                    for i in 0..current_losses.len() {
175                        let rate = current_losses[i] / initial[i].max(1e-8);
176                        relative_rates.push(rate);
177                    }
178
179                    let mean_rate: f64 =
180                        relative_rates.iter().sum::<f64>() / relative_rates.len() as f64;
181
182                    // Update weights to balance training rates
183                    for (i, &rate) in relative_rates.iter().enumerate() {
184                        let target_rate = mean_rate * self.task_weights[i].powf(alpha);
185                        let adjustment = (target_rate / rate.max(1e-8)).ln();
186                        self.task_weights[i] *= (self.weight_lr * adjustment).exp();
187                    }
188
189                    // Normalize weights
190                    let sum: f64 = self.task_weights.iter().sum();
191                    for w in &mut self.task_weights {
192                        *w /= sum;
193                    }
194                }
195                Ok(())
196            }
197            TaskWeightingStrategy::UncertaintyWeighting => {
198                // Uncertainty weighting: 1 / (2 * sigma^2) per task
199                // In practice, sigma would be learned parameters
200                // Here we use a simplified version based on loss variance
201                Ok(())
202            }
203        }
204    }
205
206    /// Get current task weights.
207    pub fn get_weights(&self) -> &[f64] {
208        &self.task_weights
209    }
210
211    /// Get number of tasks.
212    pub fn num_tasks(&self) -> usize {
213        self.task_losses.len()
214    }
215}
216
/// PCGrad: Project conflicting gradients for multi-task learning.
///
/// This implements "Gradient Surgery for Multi-Task Learning" (Yu et al., 2020).
/// When gradients from different tasks conflict (negative cosine similarity),
/// it projects the conflicting gradient onto the normal plane of the other.
///
/// Stateless marker type; all functionality lives in the associated
/// [`PCGrad::apply`] function.
pub struct PCGrad;
223
224impl PCGrad {
225    /// Apply PCGrad to balance gradients from multiple tasks.
226    ///
227    /// # Arguments
228    /// * `task_gradients` - Gradients for each task and parameter
229    ///
230    /// # Returns
231    /// Combined gradients with conflicts resolved
232    pub fn apply(
233        task_gradients: &[HashMap<String, Array<f64, Ix2>>],
234    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
235        if task_gradients.is_empty() {
236            return Err(TrainError::OptimizerError(
237                "PCGrad requires at least one task".to_string(),
238            ));
239        }
240
241        let n_tasks = task_gradients.len();
242        if n_tasks == 1 {
243            return Ok(task_gradients[0].clone());
244        }
245
246        // Get all parameter names
247        let param_names: Vec<String> = task_gradients[0].keys().cloned().collect();
248
249        let mut combined_gradients = HashMap::new();
250
251        // For each parameter
252        for param_name in param_names {
253            // Collect gradients for this parameter from all tasks
254            let mut grads: Vec<&Array<f64, Ix2>> = Vec::new();
255            for task_grad in task_gradients {
256                if let Some(grad) = task_grad.get(&param_name) {
257                    grads.push(grad);
258                }
259            }
260
261            if grads.len() != n_tasks {
262                continue; // Skip if not all tasks have this parameter
263            }
264
265            // Apply PCGrad algorithm
266            let mut modified_grads: Vec<Array<f64, Ix2>> = Vec::new();
267
268            for (i, grad) in grads.iter().enumerate() {
269                let mut grad_i = (*grad).clone();
270
271                // Project onto normal plane of other tasks if conflicting
272                for (j, other_grad) in grads.iter().enumerate() {
273                    if i == j {
274                        continue;
275                    }
276
277                    // Compute cosine similarity
278                    let dot_product: f64 = grad_i
279                        .iter()
280                        .zip(other_grad.iter())
281                        .map(|(a, b)| a * b)
282                        .sum();
283
284                    // If negative (conflicting), project
285                    if dot_product < 0.0 {
286                        let norm_j_sq: f64 = other_grad.iter().map(|x| x * x).sum();
287
288                        if norm_j_sq > 1e-8 {
289                            // Project: g_i = g_i - (g_i ยท g_j / ||g_j||^2) * g_j
290                            let scale = dot_product / norm_j_sq;
291                            grad_i = &grad_i - &(*other_grad * scale);
292                        }
293                    }
294                }
295
296                modified_grads.push(grad_i);
297            }
298
299            // Average the modified gradients
300            let mut combined = Array::zeros(grads[0].raw_dim());
301            for grad in &modified_grads {
302                combined = &combined + grad;
303            }
304            combined.mapv_inplace(|x| x / n_tasks as f64);
305
306            combined_gradients.insert(param_name.clone(), combined);
307        }
308
309        Ok(combined_gradients)
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MseLoss;
    use scirs2_core::array;

    #[test]
    fn test_multitask_loss_fixed() {
        // Two MSE tasks combined with unequal fixed weights.
        let task_losses: Vec<Box<dyn Loss>> = vec![Box::new(MseLoss), Box::new(MseLoss)];
        let mut mt_loss = MultiTaskLoss::new_fixed(task_losses, vec![0.7, 0.3]).unwrap();

        // Two tasks, two output columns each.
        let splits = vec![0, 2, 4];
        let preds = array![[1.0, 2.0, 3.0, 4.0]];
        let targs = array![[1.5, 2.5, 2.5, 3.5]];

        let total = mt_loss
            .compute_multi_task(&preds.view(), &targs.view(), &splits)
            .unwrap();

        assert!(total > 0.0);
        assert_eq!(mt_loss.get_weights(), &[0.7, 0.3]);
    }

    #[test]
    fn test_multitask_loss_dtp() {
        let task_losses: Vec<Box<dyn Loss>> = vec![Box::new(MseLoss), Box::new(MseLoss)];
        let mut mt_loss = MultiTaskLoss::new_dynamic(
            task_losses,
            TaskWeightingStrategy::DynamicTaskPrioritization,
            0.01,
        )
        .unwrap();

        // Second task is deliberately far off target.
        let splits = vec![0, 2, 4];
        let preds = array![[1.0, 2.0, 10.0, 11.0]];
        let targs = array![[1.5, 2.5, 2.0, 3.0]];

        mt_loss
            .compute_multi_task(&preds.view(), &targs.view(), &splits)
            .unwrap();

        // DTP should shift weight toward the harder (higher-loss) task.
        let weights = mt_loss.get_weights();
        assert!(weights[1] > weights[0], "Task 2 should have higher weight");
    }

    #[test]
    fn test_pcgrad_no_conflict() {
        // Identical gradients never conflict, so PCGrad reduces to a mean.
        let mut first = HashMap::new();
        first.insert("param".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
        let mut second = HashMap::new();
        second.insert("param".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let combined_map = PCGrad::apply(&[first, second]).unwrap();
        let combined = combined_map.get("param").unwrap();

        // The average of two equal gradients is the gradient itself.
        assert!((combined[[0, 0]] - 1.0).abs() < 1e-6);
        assert!((combined[[1, 1]] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_pcgrad_conflict() {
        // Opposite-direction gradients must be projected against each other.
        let mut first = HashMap::new();
        first.insert("param".to_string(), array![[1.0, 0.0]]);
        let mut second = HashMap::new();
        second.insert("param".to_string(), array![[-1.0, 0.0]]);

        let combined_map = PCGrad::apply(&[first, second]).unwrap();
        let combined = combined_map.get("param").unwrap();

        // The conflicting component should shrink below its raw magnitude.
        assert!(combined[[0, 0]].abs() < 1.0);
    }

    #[test]
    fn test_multitask_invalid_splits() {
        // Two tasks require three split points; supply only two.
        let task_losses: Vec<Box<dyn Loss>> = vec![Box::new(MseLoss), Box::new(MseLoss)];
        let mut mt_loss = MultiTaskLoss::new_fixed(task_losses, vec![0.5, 0.5]).unwrap();

        let bad_splits = vec![0, 1];
        let preds = array![[1.0, 2.0]];
        let targs = array![[1.5, 2.5]];

        let outcome = mt_loss.compute_multi_task(&preds.view(), &targs.view(), &bad_splits);
        assert!(outcome.is_err());
    }
}