use crate::{Loss, TrainError, TrainResult};
use scirs2_core::ndarray::{s, Array, ArrayView, Ix2};
use std::collections::HashMap;

/// Strategy used to set the per-task weights of a [`MultiTaskLoss`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TaskWeightingStrategy {
    /// Weights are fixed at construction and never updated.
    Fixed,
    /// Weights are set to each task's share of the total loss, so the
    /// worst-performing tasks are prioritized.
    DynamicTaskPrioritization,
    /// Loss-ratio variant of GradNorm (Chen et al., 2018) with balancing
    /// exponent `alpha`.
    GradNorm { alpha: f64 },
    /// Homoscedastic uncertainty weighting (Kendall et al., 2018). This
    /// requires learnable log-variance parameters, so no weight update is
    /// performed here.
    UncertaintyWeighting,
}

/// Combines several per-task losses into a single weighted scalar loss.
pub struct MultiTaskLoss {
    /// Loss function for each task.
    pub task_losses: Vec<Box<dyn Loss>>,
    /// Current weight assigned to each task.
    pub task_weights: Vec<f64>,
    /// Strategy used to update the task weights.
    pub strategy: TaskWeightingStrategy,
    /// Learning rate for dynamic weight updates.
    pub weight_lr: f64,
    /// Per-task losses from the first call, used as the GradNorm baseline.
    initial_losses: Option<Vec<f64>>,
}

impl MultiTaskLoss {
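    /// Creates a multi-task loss with fixed, user-supplied task weights.
    ///
    /// A minimal usage sketch (marked `ignore` because it assumes this
    /// crate's `MseLoss` is in scope and implements [`Loss`]):
    ///
    /// ```ignore
    /// let losses: Vec<Box<dyn Loss>> = vec![Box::new(MseLoss), Box::new(MseLoss)];
    /// let mt = MultiTaskLoss::new_fixed(losses, vec![0.7, 0.3])?;
    /// assert_eq!(mt.num_tasks(), 2);
    /// ```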
    pub fn new_fixed(task_losses: Vec<Box<dyn Loss>>, task_weights: Vec<f64>) -> TrainResult<Self> {
        if task_losses.len() != task_weights.len() {
            return Err(TrainError::ConfigError(
                "Number of losses must match number of weights".to_string(),
            ));
        }

        if task_losses.is_empty() {
            return Err(TrainError::ConfigError(
                "Must have at least one task".to_string(),
            ));
        }

        Ok(Self {
            task_losses,
            task_weights,
            strategy: TaskWeightingStrategy::Fixed,
            weight_lr: 0.0,
            initial_losses: None,
        })
    }

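    /// Creates a multi-task loss with uniform initial weights that are
    /// updated by `strategy` on every call to
    /// [`compute_multi_task`](Self::compute_multi_task).
    ///
    /// A minimal sketch (marked `ignore`; `MseLoss` is assumed to be this
    /// crate's implementor of [`Loss`]):
    ///
    /// ```ignore
    /// let losses: Vec<Box<dyn Loss>> = vec![Box::new(MseLoss), Box::new(MseLoss)];
    /// let mt = MultiTaskLoss::new_dynamic(
    ///     losses,
    ///     TaskWeightingStrategy::GradNorm { alpha: 1.5 },
    ///     0.01,
    /// )?;
    /// assert_eq!(mt.get_weights(), &[0.5, 0.5]); // uniform start
    /// ```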
    pub fn new_dynamic(
        task_losses: Vec<Box<dyn Loss>>,
        strategy: TaskWeightingStrategy,
        weight_lr: f64,
    ) -> TrainResult<Self> {
        if task_losses.is_empty() {
            return Err(TrainError::ConfigError(
                "Must have at least one task".to_string(),
            ));
        }

        let n_tasks = task_losses.len();
        let task_weights = vec![1.0 / n_tasks as f64; n_tasks];

        Ok(Self {
            task_losses,
            task_weights,
            strategy,
            weight_lr,
            initial_losses: None,
        })
    }

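    /// Computes the weighted multi-task loss over column-partitioned outputs.
    ///
    /// `task_splits` holds column boundaries: task `i` covers columns
    /// `task_splits[i]..task_splits[i + 1]`, so the slice must contain
    /// `n_tasks + 1` increasing entries.
    ///
    /// A minimal sketch (marked `ignore`; it assumes `MseLoss` and the
    /// `array!` macro from `scirs2_core` are in scope):
    ///
    /// ```ignore
    /// let losses: Vec<Box<dyn Loss>> = vec![Box::new(MseLoss), Box::new(MseLoss)];
    /// let mut mt = MultiTaskLoss::new_fixed(losses, vec![0.5, 0.5])?;
    ///
    /// let predictions = array![[1.0, 2.0, 3.0, 4.0]];
    /// let targets = array![[1.0, 2.0, 3.0, 4.0]];
    /// // Task 1 owns columns 0..2, task 2 owns columns 2..4.
    /// let loss = mt.compute_multi_task(&predictions.view(), &targets.view(), &[0, 2, 4])?;
    /// assert_eq!(loss, 0.0); // predictions match targets exactly
    /// ```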
    pub fn compute_multi_task(
        &mut self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
        task_splits: &[usize],
    ) -> TrainResult<f64> {
        if task_splits.len() != self.task_losses.len() + 1 {
            return Err(TrainError::LossError(format!(
                "task_splits must have {} elements (n_tasks + 1)",
                self.task_losses.len() + 1
            )));
        }

        let mut task_losses_values = Vec::new();

        for i in 0..self.task_losses.len() {
            // Slice out the columns belonging to task `i`.
            let start = task_splits[i];
            let end = task_splits[i + 1];

            let task_pred = predictions.slice(s![.., start..end]);
            let task_target = targets.slice(s![.., start..end]);

            let loss_value = self.task_losses[i].compute(&task_pred, &task_target)?;
            task_losses_values.push(loss_value);
        }

        // Record the losses from the first call as the GradNorm baseline.
        if self.initial_losses.is_none() {
            self.initial_losses = Some(task_losses_values.clone());
        }

        self.update_weights(&task_losses_values)?;

        // Weighted sum of the per-task losses.
        let total_loss = task_losses_values
            .iter()
            .zip(self.task_weights.iter())
            .map(|(loss, weight)| loss * weight)
            .sum();

        Ok(total_loss)
    }

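    /// Updates `task_weights` in place according to the configured strategy.
    ///
    /// The `GradNorm` arm is a loss-ratio approximation of the original
    /// algorithm (which balances measured gradient norms): each weight is
    /// scaled by `exp(weight_lr * ln(target_rate / rate))` and the weights
    /// are then renormalized to sum to 1.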
    fn update_weights(&mut self, current_losses: &[f64]) -> TrainResult<()> {
        match self.strategy {
            TaskWeightingStrategy::Fixed => {
                // Fixed weights are never updated.
                Ok(())
            }
            TaskWeightingStrategy::DynamicTaskPrioritization => {
                // Weight each task by its share of the total loss, so the
                // worst-performing tasks receive the largest weights.
                let sum: f64 = current_losses.iter().sum();
                if sum > 1e-8 {
                    for (i, &loss) in current_losses.iter().enumerate() {
                        self.task_weights[i] = loss / sum;
                    }
                }
                Ok(())
            }
            TaskWeightingStrategy::GradNorm { alpha } => {
                if let Some(ref initial) = self.initial_losses {
                    // Relative training rate: how much each task's loss has
                    // fallen relative to its initial value.
                    let mut relative_rates = Vec::new();
                    for i in 0..current_losses.len() {
                        let rate = current_losses[i] / initial[i].max(1e-8);
                        relative_rates.push(rate);
                    }

                    let mean_rate: f64 =
                        relative_rates.iter().sum::<f64>() / relative_rates.len() as f64;

                    // Multiplicative update toward each task's target rate.
                    for (i, &rate) in relative_rates.iter().enumerate() {
                        let target_rate = mean_rate * self.task_weights[i].powf(alpha);
                        let adjustment = (target_rate / rate.max(1e-8)).ln();
                        self.task_weights[i] *= (self.weight_lr * adjustment).exp();
                    }

                    // Renormalize so the weights sum to 1.
                    let sum: f64 = self.task_weights.iter().sum();
                    for w in &mut self.task_weights {
                        *w /= sum;
                    }
                }
                Ok(())
            }
            TaskWeightingStrategy::UncertaintyWeighting => {
                // Uncertainty weighting requires learnable log-variance
                // parameters updated by the optimizer, so there is nothing
                // to adjust here.
                Ok(())
            }
        }
    }

    /// Returns the current task weights.
    pub fn get_weights(&self) -> &[f64] {
        &self.task_weights
    }

    /// Returns the number of tasks.
    pub fn num_tasks(&self) -> usize {
        self.task_losses.len()
    }
}

/// PCGrad ("projecting conflicting gradients", Yu et al., 2020): when two
/// task gradients conflict (negative dot product), each is projected onto
/// the normal plane of the other before the per-task gradients are averaged.
pub struct PCGrad;

impl PCGrad {
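    /// Applies gradient surgery to per-task, per-parameter gradients and
    /// returns the averaged, conflict-free gradients.
    ///
    /// A minimal sketch (marked `ignore`; it assumes the `array!` macro from
    /// `scirs2_core` is in scope):
    ///
    /// ```ignore
    /// let mut task_grads = vec![HashMap::new(), HashMap::new()];
    /// task_grads[0].insert("w".to_string(), array![[1.0, 0.0]]);
    /// task_grads[1].insert("w".to_string(), array![[-1.0, 0.0]]);
    ///
    /// // Fully opposed gradients cancel after projection.
    /// let combined = PCGrad::apply(&task_grads)?;
    /// assert!(combined["w"][[0, 0]].abs() < 1e-12);
    /// ```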
    pub fn apply(
        task_gradients: &[HashMap<String, Array<f64, Ix2>>],
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        if task_gradients.is_empty() {
            return Err(TrainError::OptimizerError(
                "PCGrad requires at least one task".to_string(),
            ));
        }

        let n_tasks = task_gradients.len();
        if n_tasks == 1 {
            // Nothing can conflict with a single task.
            return Ok(task_gradients[0].clone());
        }

        let param_names: Vec<String> = task_gradients[0].keys().cloned().collect();

        let mut combined_gradients = HashMap::new();

        for param_name in param_names {
            // Collect this parameter's gradient from every task.
            let mut grads: Vec<&Array<f64, Ix2>> = Vec::new();
            for task_grad in task_gradients {
                if let Some(grad) = task_grad.get(&param_name) {
                    grads.push(grad);
                }
            }

            // Skip parameters that are missing from any task.
            if grads.len() != n_tasks {
                continue;
            }

            let mut modified_grads: Vec<Array<f64, Ix2>> = Vec::new();

            for (i, grad) in grads.iter().enumerate() {
                let mut grad_i = (*grad).clone();

                for (j, other_grad) in grads.iter().enumerate() {
                    if i == j {
                        continue;
                    }

                    let dot_product: f64 = grad_i
                        .iter()
                        .zip(other_grad.iter())
                        .map(|(a, b)| a * b)
                        .sum();

                    // A negative dot product means the gradients conflict.
                    if dot_product < 0.0 {
                        let norm_j_sq: f64 = other_grad.iter().map(|x| x * x).sum();

                        if norm_j_sq > 1e-8 {
                            // Project grad_i onto the normal plane of grad_j.
                            let scale = dot_product / norm_j_sq;
                            grad_i = &grad_i - &(*other_grad * scale);
                        }
                    }
                }

                modified_grads.push(grad_i);
            }

            // Average the projected gradients across tasks.
            let mut combined = Array::zeros(grads[0].raw_dim());
            for grad in &modified_grads {
                combined = &combined + grad;
            }
            combined.mapv_inplace(|x| x / n_tasks as f64);

            combined_gradients.insert(param_name, combined);
        }

        Ok(combined_gradients)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::MseLoss;
    use scirs2_core::array;

    #[test]
    fn test_multitask_loss_fixed() {
        let losses: Vec<Box<dyn Loss>> = vec![Box::new(MseLoss), Box::new(MseLoss)];
        let weights = vec![0.7, 0.3];

        let mut mt_loss = MultiTaskLoss::new_fixed(losses, weights).unwrap();

        let predictions = array![[1.0, 2.0, 3.0, 4.0]];
        let targets = array![[1.5, 2.5, 2.5, 3.5]];
        // Task 1 covers columns 0..2, task 2 covers columns 2..4.
        let task_splits = vec![0, 2, 4];

        let loss = mt_loss
            .compute_multi_task(&predictions.view(), &targets.view(), &task_splits)
            .unwrap();

        assert!(loss > 0.0);
        assert_eq!(mt_loss.get_weights(), &[0.7, 0.3]);
    }

    #[test]
    fn test_multitask_loss_dtp() {
        let losses: Vec<Box<dyn Loss>> = vec![Box::new(MseLoss), Box::new(MseLoss)];

        let mut mt_loss = MultiTaskLoss::new_dynamic(
            losses,
            TaskWeightingStrategy::DynamicTaskPrioritization,
            0.01,
        )
        .unwrap();

        // Task 2 has much larger errors than task 1.
        let predictions = array![[1.0, 2.0, 10.0, 11.0]];
        let targets = array![[1.5, 2.5, 2.0, 3.0]];
        let task_splits = vec![0, 2, 4];

        let _loss = mt_loss
            .compute_multi_task(&predictions.view(), &targets.view(), &task_splits)
            .unwrap();

        let weights = mt_loss.get_weights();
        assert!(weights[1] > weights[0], "Task 2 should have higher weight");
    }

    #[test]
    fn test_pcgrad_no_conflict() {
        // Identical gradients do not conflict, so averaging leaves them unchanged.
        let grad1 = array![[1.0, 2.0], [3.0, 4.0]];
        let grad2 = array![[1.0, 2.0], [3.0, 4.0]];

        let mut task_grads = vec![HashMap::new(), HashMap::new()];
        task_grads[0].insert("param".to_string(), grad1);
        task_grads[1].insert("param".to_string(), grad2);

        let result = PCGrad::apply(&task_grads).unwrap();
        let combined = result.get("param").unwrap();

        assert!((combined[[0, 0]] - 1.0).abs() < 1e-6);
        assert!((combined[[1, 1]] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_pcgrad_conflict() {
        // Directly opposing gradients should be attenuated by projection.
        let grad1 = array![[1.0, 0.0]];
        let grad2 = array![[-1.0, 0.0]];

        let mut task_grads = vec![HashMap::new(), HashMap::new()];
        task_grads[0].insert("param".to_string(), grad1);
        task_grads[1].insert("param".to_string(), grad2);

        let result = PCGrad::apply(&task_grads).unwrap();
        let combined = result.get("param").unwrap();

        assert!(combined[[0, 0]].abs() < 1.0);
    }

    #[test]
    fn test_multitask_invalid_splits() {
        let losses: Vec<Box<dyn Loss>> = vec![Box::new(MseLoss), Box::new(MseLoss)];
        let mut mt_loss = MultiTaskLoss::new_fixed(losses, vec![0.5, 0.5]).unwrap();

        let predictions = array![[1.0, 2.0]];
        let targets = array![[1.5, 2.5]];
        // Two tasks require three split points; two is invalid.
        let task_splits = vec![0, 1];

        let result =
            mt_loss.compute_multi_task(&predictions.view(), &targets.view(), &task_splits);
        assert!(result.is_err());
    }
}