sklears_multioutput/optimization/
joint_loss_optimization.rs

//! Joint Loss Optimization for Multi-Output Learning
//!
//! This module provides joint loss optimization techniques for multi-output learning problems,
//! where multiple output losses are combined using various strategies to optimize all outputs
//! simultaneously rather than independently.
//!
//! ## Key Features
//!
//! - **Multiple Loss Functions**: Support for MSE, MAE, Huber, Cross-entropy, and Hinge losses
//! - **Flexible Loss Combination**: Sum, weighted sum, max, geometric mean, and adaptive strategies
//! - **Gradient-Based Optimization**: Efficient gradient computation for joint loss minimization
//! - **Regularization**: L2 regularization to prevent overfitting
//! - **Configurable Training**: Customizable learning rate, iterations, and convergence criteria
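//!
//! ## Example
//!
//! A minimal usage sketch (not a doctest): the data, hyperparameters, and the way this
//! module's types are brought into scope are illustrative assumptions.
//!
//! ```ignore
//! use scirs2_core::ndarray::Array2;
//! use sklears_core::traits::{Fit, Predict};
//!
//! // 4 samples, 2 features, 2 outputs; MSE on the first output, Huber on the second.
//! let x = Array2::from_shape_vec((4, 2), vec![0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).unwrap();
//! let y = Array2::from_shape_vec((4, 2), vec![1.0, 0.5, 0.0, 1.5, 1.0, 2.0, 0.0, 0.0]).unwrap();
//!
//! let model = JointLossOptimizer::new()
//!     .output_losses(vec![LossFunction::MSE, LossFunction::Huber(1.0)])
//!     .combination(LossCombination::WeightedSum(vec![0.7, 0.3]))
//!     .learning_rate(0.05)
//!     .max_iter(500)
//!     .fit(&x.view(), &y.view())
//!     .expect("training failed");
//!
//! let predictions = model.predict(&x.view()).expect("prediction failed"); // shape (4, 2)
//! ```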

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

/// Loss function types for joint optimization
#[derive(Debug, Clone, PartialEq)]
pub enum LossFunction {
    /// Mean Squared Error
    MSE,
    /// Mean Absolute Error
    MAE,
    /// Huber Loss with configurable delta
    Huber(Float),
    /// Cross-entropy loss
    CrossEntropy,
    /// Hinge loss
    Hinge,
    /// Custom loss function (not yet implemented; fitting with it currently returns an error)
    Custom(String),
}

/// Loss combination strategies for joint optimization
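///
/// `WeightedSum` expects exactly one weight per output; a sketch of selecting it
/// (the weights below are illustrative):
///
/// ```ignore
/// // Favor the first output's loss over the second's when combining.
/// let combination = LossCombination::WeightedSum(vec![0.7, 0.3]);
/// ```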
#[derive(Debug, Clone, PartialEq)]
pub enum LossCombination {
    /// Simple sum of individual losses
    Sum,
    /// Weighted sum of individual losses
    WeightedSum(Vec<Float>),
    /// Maximum of individual losses
    Max,
    /// Geometric mean of individual losses
    GeometricMean,
    /// Adaptive weighting based on loss magnitudes
    Adaptive,
}

/// Joint Loss Optimization configuration
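///
/// A sketch of overriding selected fields on top of [`Default`] (the values are
/// illustrative, not tuned recommendations):
///
/// ```ignore
/// let config = JointLossConfig {
///     output_losses: vec![LossFunction::MSE, LossFunction::MAE],
///     combination: LossCombination::Adaptive,
///     learning_rate: 0.05,
///     ..Default::default()
/// };
/// let optimizer = JointLossOptimizer::new().config(config);
/// ```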
#[derive(Debug, Clone)]
pub struct JointLossConfig {
    /// Loss function for each output
    pub output_losses: Vec<LossFunction>,
    /// Loss combination strategy
    pub combination: LossCombination,
    /// Regularization strength
    pub regularization: Float,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: Float,
    /// Learning rate
    pub learning_rate: Float,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}

impl Default for JointLossConfig {
    fn default() -> Self {
        Self {
            output_losses: vec![LossFunction::MSE],
            combination: LossCombination::Sum,
            regularization: 0.01,
            max_iter: 1000,
            tol: 1e-6,
            learning_rate: 0.01,
            random_state: None,
        }
    }
}

/// Joint Loss Optimizer for multi-output learning
#[derive(Debug, Clone)]
pub struct JointLossOptimizer<S = Untrained> {
    state: S,
    config: JointLossConfig,
}

/// Trained state for Joint Loss Optimizer
#[derive(Debug, Clone)]
pub struct JointLossOptimizerTrained {
    /// Model weights
    pub weights: Array2<Float>,
    /// Bias terms
    pub bias: Array1<Float>,
    /// Number of features
    pub n_features: usize,
    /// Number of outputs
    pub n_outputs: usize,
    /// Training history
    pub loss_history: Vec<Float>,
    /// Configuration used for training
    pub config: JointLossConfig,
}

impl JointLossOptimizer<Untrained> {
    /// Create a new Joint Loss Optimizer
    pub fn new() -> Self {
        Self {
            state: Untrained,
            config: JointLossConfig::default(),
        }
    }

    /// Set the configuration
    pub fn config(mut self, config: JointLossConfig) -> Self {
        self.config = config;
        self
    }

    /// Set the output loss functions
    pub fn output_losses(mut self, losses: Vec<LossFunction>) -> Self {
        self.config.output_losses = losses;
        self
    }

    /// Set the loss combination strategy
    pub fn combination(mut self, combination: LossCombination) -> Self {
        self.config.combination = combination;
        self
    }

    /// Set the regularization strength
    pub fn regularization(mut self, regularization: Float) -> Self {
        self.config.regularization = regularization;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.config.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tol(mut self, tol: Float) -> Self {
        self.config.tol = tol;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.config.learning_rate = learning_rate;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.config.random_state = random_state;
        self
    }
}

impl Default for JointLossOptimizer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for JointLossOptimizer<Untrained> {
    type Config = JointLossConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for JointLossOptimizer<Untrained> {
    type Fitted = JointLossOptimizer<JointLossOptimizerTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView2<'_, Float>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let (y_samples, n_outputs) = y.dim();

        if n_samples != y_samples {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        if n_outputs != self.config.output_losses.len() {
            return Err(SklearsError::InvalidInput(format!(
                "Number of outputs ({}) must match number of loss functions ({})",
                n_outputs,
                self.config.output_losses.len()
            )));
        }

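        // NOTE: `random_state` is not currently wired into weight initialization;
        // a thread-local RNG is always used, so runs are not bit-for-bit reproducible.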
        let mut rng = thread_rng();

        // Initialize weights using Xavier initialization
        let std_dev = (2.0 / (n_features + n_outputs) as Float).sqrt();
        let normal_dist = RandNormal::new(0.0, std_dev).unwrap();
        let mut weights = Array2::<Float>::zeros((n_features, n_outputs));
        for i in 0..n_features {
            for j in 0..n_outputs {
                weights[[i, j]] = rng.sample(normal_dist);
            }
        }
        let mut bias = Array1::<Float>::zeros(n_outputs);

        let mut loss_history = Vec::new();
        let mut prev_loss = Float::INFINITY;

        for _iteration in 0..self.config.max_iter {
            // Forward pass
            let predictions = X.dot(&weights) + &bias;

            // Compute joint loss
            let joint_loss = self.compute_joint_loss(&predictions, y)?;
            loss_history.push(joint_loss);

            // Check convergence
            if (prev_loss - joint_loss).abs() < self.config.tol {
                break;
            }
            prev_loss = joint_loss;

            // Compute gradients
            let (weight_gradients, bias_gradients) = self.compute_gradients(X, y, &predictions)?;

            // Update weights and bias
            weights = weights - self.config.learning_rate * weight_gradients;
            bias = bias - self.config.learning_rate * bias_gradients;

            // Apply regularization
            if self.config.regularization > 0.0 {
                weights *= 1.0 - self.config.regularization * self.config.learning_rate;
            }
        }

        Ok(JointLossOptimizer {
            state: JointLossOptimizerTrained {
                weights,
                bias,
                n_features,
                n_outputs,
                loss_history,
                config: self.config.clone(),
            },
            config: self.config,
        })
    }
}

impl JointLossOptimizer<Untrained> {
    /// Compute joint loss based on the combination strategy
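    ///
    /// For per-output losses `l_1, ..., l_k`, the strategies reduce to:
    /// `Sum` = Σ l_i, `WeightedSum(w)` = Σ w_i * l_i, `Max` = max_i l_i,
    /// `GeometricMean` = (Π l_i)^(1/k), and `Adaptive` = Σ (l_i / Σ_j l_j) * l_i,
    /// i.e. each loss is weighted by its share of the total loss.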
    fn compute_joint_loss(
        &self,
        predictions: &Array2<Float>,
        y: &ArrayView2<'_, Float>,
    ) -> SklResult<Float> {
        let mut individual_losses = Vec::new();

        for (i, loss_fn) in self.config.output_losses.iter().enumerate() {
            let pred_col = predictions.column(i);
            let y_col = y.column(i);
            let loss = self.compute_individual_loss(loss_fn, &pred_col, &y_col)?;
            individual_losses.push(loss);
        }

        let joint_loss = match &self.config.combination {
            LossCombination::Sum => individual_losses.iter().sum(),
            LossCombination::WeightedSum(weights) => {
                if weights.len() != individual_losses.len() {
                    return Err(SklearsError::InvalidInput(
                        "Weight vector length must match number of outputs".to_string(),
                    ));
                }
                individual_losses
                    .iter()
                    .zip(weights.iter())
                    .map(|(loss, weight)| loss * weight)
                    .sum()
            }
            LossCombination::Max => individual_losses.iter().cloned().fold(0.0, Float::max),
            LossCombination::GeometricMean => {
                let product: Float = individual_losses.iter().product();
                product.powf(1.0 / individual_losses.len() as Float)
            }
            LossCombination::Adaptive => {
                // Adaptive weighting based on loss magnitudes
                let total_loss: Float = individual_losses.iter().sum();
                if total_loss > 0.0 {
                    let weights: Vec<Float> = individual_losses
                        .iter()
                        .map(|&loss| loss / total_loss)
                        .collect();
                    individual_losses
                        .iter()
                        .zip(weights.iter())
                        .map(|(loss, weight)| loss * weight)
                        .sum()
                } else {
                    0.0
                }
            }
        };

        Ok(joint_loss)
    }

    /// Compute individual loss for a specific output
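    ///
    /// Every loss is averaged over samples. `Huber(delta)` uses `0.5 * x^2` when
    /// `|x| <= delta` and `delta * |x| - 0.5 * delta^2` otherwise, where `x` is the
    /// prediction error; `CrossEntropy` clips predictions to `(1e-15, 1 - 1e-15)`
    /// and treats them as probabilities; `Hinge` uses `max(0, 1 - y * pred)` and
    /// expects targets in `{-1, +1}`.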
    fn compute_individual_loss(
        &self,
        loss_fn: &LossFunction,
        predictions: &ArrayView1<'_, Float>,
        y: &ArrayView1<'_, Float>,
    ) -> SklResult<Float> {
        match loss_fn {
            LossFunction::MSE => {
                let diff = predictions - y;
                Ok(diff.mapv(|x| x * x).mean().unwrap_or(0.0))
            }
            LossFunction::MAE => {
                let diff = predictions - y;
                Ok(diff.mapv(|x| x.abs()).mean().unwrap_or(0.0))
            }
            LossFunction::Huber(delta) => {
                let diff = predictions - y;
                let huber_loss = diff.mapv(|x| {
                    if x.abs() <= *delta {
                        0.5 * x * x
                    } else {
                        delta * x.abs() - 0.5 * delta * delta
                    }
                });
                Ok(huber_loss.mean().unwrap_or(0.0))
            }
            LossFunction::CrossEntropy => {
                // Binary cross-entropy: predictions are clipped and treated directly as
                // probabilities (no sigmoid is applied to the linear outputs)
                let epsilon = 1e-15;
                let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(1.0 - epsilon));
                let loss = y
                    .iter()
                    .zip(clipped_preds.iter())
                    .map(|(y_true, y_pred)| {
                        -(y_true * y_pred.ln() + (1.0 - y_true) * (1.0 - y_pred).ln())
                    })
                    .sum::<Float>()
                    / y.len() as Float;
                Ok(loss)
            }
            LossFunction::Hinge => {
                let loss = predictions
                    .iter()
                    .zip(y.iter())
                    .map(|(pred, true_val)| {
                        let margin = true_val * pred;
                        if margin < 1.0 {
                            1.0 - margin
                        } else {
                            0.0
                        }
                    })
                    .sum::<Float>()
                    / y.len() as Float;
                Ok(loss)
            }
            LossFunction::Custom(_) => Err(SklearsError::InvalidInput(
                "Custom loss functions are not yet implemented".to_string(),
            )),
        }
    }

    /// Compute gradients for weights and bias
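    ///
    /// Weight gradients are `X^T * g / n_samples` and bias gradients are `mean(g)`,
    /// where `g` is the per-sample gradient of each output's loss with respect to
    /// its prediction. Gradients are computed per output independently; the loss
    /// combination strategy only affects the reported joint loss, not the updates.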
    fn compute_gradients(
        &self,
        X: &ArrayView2<'_, Float>,
        y: &ArrayView2<'_, Float>,
        predictions: &Array2<Float>,
    ) -> SklResult<(Array2<Float>, Array1<Float>)> {
        let (n_samples, n_features) = X.dim();
        let n_outputs = y.ncols();

        let mut weight_gradients = Array2::<Float>::zeros((n_features, n_outputs));
        let mut bias_gradients = Array1::<Float>::zeros(n_outputs);

        for (i, loss_fn) in self.config.output_losses.iter().enumerate() {
            let pred_col = predictions.column(i);
            let y_col = y.column(i);

            // Compute gradient for this output
            let output_gradient = self.compute_output_gradient(loss_fn, &pred_col, &y_col)?;

            // Update weight gradients
            for j in 0..n_features {
                weight_gradients[(j, i)] = X.column(j).dot(&output_gradient) / n_samples as Float;
            }

            // Update bias gradients
            bias_gradients[i] = output_gradient.mean().unwrap_or(0.0);
        }

        Ok((weight_gradients, bias_gradients))
    }

    /// Compute gradient for a specific output
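    ///
    /// Returns the per-sample derivative of the loss with respect to the prediction:
    /// `2 * (pred - y)` for MSE, `sign(pred - y)` for MAE, the error clipped to
    /// `±delta` for Huber, `clipped_pred - y` for cross-entropy, and `-y` inside
    /// the margin (`0` otherwise) for hinge.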
    fn compute_output_gradient(
        &self,
        loss_fn: &LossFunction,
        predictions: &ArrayView1<'_, Float>,
        y: &ArrayView1<'_, Float>,
    ) -> SklResult<Array1<Float>> {
        let gradient = match loss_fn {
            LossFunction::MSE => 2.0 * (predictions - y),
            LossFunction::MAE => (predictions - y).mapv(|x| {
                if x > 0.0 {
                    1.0
                } else if x < 0.0 {
                    -1.0
                } else {
                    0.0
                }
            }),
            LossFunction::Huber(delta) => {
                let diff = predictions - y;
                diff.mapv(|x| {
                    if x.abs() <= *delta {
                        x
                    } else {
                        delta * x.signum()
                    }
                })
            }
            LossFunction::CrossEntropy => {
                // Gradient of binary cross-entropy, treating the clipped predictions as probabilities
                let epsilon = 1e-15;
                let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(1.0 - epsilon));
                &clipped_preds - y
            }
            LossFunction::Hinge => predictions
                .iter()
                .zip(y.iter())
                .map(|(pred, true_val)| {
                    let margin = true_val * pred;
                    if margin < 1.0 {
                        -true_val
                    } else {
                        0.0
                    }
                })
                .collect::<Array1<Float>>(),
            LossFunction::Custom(_) => {
                return Err(SklearsError::InvalidInput(
                    "Custom loss functions are not yet implemented".to_string(),
                ));
            }
        };

        Ok(gradient)
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<Float>>
    for JointLossOptimizer<JointLossOptimizerTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let n_features = X.ncols();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "Expected {} features, got {}",
                self.state.n_features, n_features
            )));
        }

        let predictions = X.dot(&self.state.weights) + &self.state.bias;
        Ok(predictions)
    }
}

impl Estimator for JointLossOptimizer<JointLossOptimizerTrained> {
    type Config = JointLossConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.state.config
    }
}

impl JointLossOptimizer<JointLossOptimizerTrained> {
    /// Get the training loss history
    pub fn loss_history(&self) -> &[Float] {
        &self.state.loss_history
    }

    /// Get the model weights
    pub fn weights(&self) -> &Array2<Float> {
        &self.state.weights
    }

    /// Get the bias terms
    pub fn bias(&self) -> &Array1<Float> {
        &self.state.bias
    }

    /// Get the number of features
    pub fn n_features(&self) -> usize {
        self.state.n_features
    }

    /// Get the number of outputs
    pub fn n_outputs(&self) -> usize {
        self.state.n_outputs
    }
}
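
// A minimal smoke-test sketch: the data, hyperparameters, and assertions below are
// illustrative assumptions rather than validated expectations for this estimator.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use sklears_core::traits::{Fit, Predict};

    #[test]
    fn fit_and_predict_preserve_shapes() {
        // 4 samples, 2 features, 2 outputs.
        let x = Array2::from_shape_vec((4, 2), vec![0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0])
            .unwrap();
        let y = Array2::from_shape_vec((4, 2), vec![1.0, 0.5, 0.0, 1.5, 1.0, 2.0, 0.0, 0.0])
            .unwrap();

        let model = JointLossOptimizer::new()
            .output_losses(vec![LossFunction::MSE, LossFunction::MAE])
            .combination(LossCombination::Sum)
            .max_iter(50)
            .fit(&x.view(), &y.view())
            .expect("fit should succeed on well-shaped inputs");

        let predictions = model.predict(&x.view()).expect("predict should succeed");
        assert_eq!(predictions.dim(), (4, 2));
        assert!(!model.loss_history().is_empty());
    }
}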