sklears_multioutput/optimization/joint_loss_optimization.rs

//! Joint loss optimization for multi-output linear models: per-output loss
//! functions, strategies for combining them, and a gradient-descent trainer.

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

/// Loss functions that can be assigned to individual outputs.
#[derive(Debug, Clone, PartialEq)]
pub enum LossFunction {
    /// Mean squared error.
    MSE,
    /// Mean absolute error.
    MAE,
    /// Huber loss with the given delta threshold.
    Huber(Float),
    /// Binary cross-entropy; expects predictions interpretable as probabilities.
    CrossEntropy,
    /// Hinge loss for margin-based targets in {-1, +1}.
    Hinge,
    /// User-defined loss identified by name (not yet implemented).
    Custom(String),
}

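// Illustrative only: each output gets its own loss, so a two-target problem
// might pair squared error with an outlier-robust Huber loss:
//
//     let losses = vec![LossFunction::MSE, LossFunction::Huber(1.0)];
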
/// Strategies for combining per-output losses into a single scalar objective.
#[derive(Debug, Clone, PartialEq)]
pub enum LossCombination {
    /// Unweighted sum of the per-output losses.
    Sum,
    /// Weighted sum; the weight vector must have one entry per output.
    WeightedSum(Vec<Float>),
    /// Maximum of the per-output losses (optimizes the worst case).
    Max,
    /// Geometric mean of the per-output losses.
    GeometricMean,
    /// Weights each loss by its share of the total, emphasizing harder outputs.
    Adaptive,
}

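// For per-output losses L_1, ..., L_k, the combinations implemented in
// `compute_joint_loss` below evaluate to:
//
//     Sum:            L_1 + ... + L_k
//     WeightedSum(w): w_1 * L_1 + ... + w_k * L_k
//     Max:            max_i L_i
//     GeometricMean:  (L_1 * ... * L_k)^(1/k)
//     Adaptive:       sum_i (L_i / sum_j L_j) * L_i
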
/// Configuration for [`JointLossOptimizer`].
#[derive(Debug, Clone)]
pub struct JointLossConfig {
    /// Loss function for each output; length must equal the number of outputs.
    pub output_losses: Vec<LossFunction>,
    /// How the per-output losses are combined into one objective.
    pub combination: LossCombination,
    /// L2 regularization strength applied to the weights.
    pub regularization: Float,
    /// Maximum number of gradient-descent iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the change in joint loss between iterations.
    pub tol: Float,
    /// Gradient-descent step size.
    pub learning_rate: Float,
    /// Optional random seed for weight initialization.
    pub random_state: Option<u64>,
}

impl Default for JointLossConfig {
    fn default() -> Self {
        Self {
            output_losses: vec![LossFunction::MSE],
            combination: LossCombination::Sum,
            regularization: 0.01,
            max_iter: 1000,
            tol: 1e-6,
            learning_rate: 0.01,
            random_state: None,
        }
    }
}

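// Configs are usually derived from the defaults via struct-update syntax
// (values here are illustrative, not tuned recommendations):
//
//     let config = JointLossConfig {
//         output_losses: vec![LossFunction::MSE, LossFunction::MAE],
//         combination: LossCombination::WeightedSum(vec![0.7, 0.3]),
//         ..JointLossConfig::default()
//     };
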
/// Multi-output linear model trained by minimizing a joint loss across outputs.
#[derive(Debug, Clone)]
pub struct JointLossOptimizer<S = Untrained> {
    state: S,
    config: JointLossConfig,
}

/// State of a fitted [`JointLossOptimizer`].
#[derive(Debug, Clone)]
pub struct JointLossOptimizerTrained {
    /// Learned weight matrix of shape `(n_features, n_outputs)`.
    pub weights: Array2<Float>,
    /// Learned bias vector of length `n_outputs`.
    pub bias: Array1<Float>,
    /// Number of features seen during fit.
    pub n_features: usize,
    /// Number of outputs seen during fit.
    pub n_outputs: usize,
    /// Joint loss recorded at each training iteration.
    pub loss_history: Vec<Float>,
    /// Configuration used for training.
    pub config: JointLossConfig,
}

impl JointLossOptimizer<Untrained> {
    /// Creates an untrained optimizer with the default configuration.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            config: JointLossConfig::default(),
        }
    }

    /// Replaces the entire configuration.
    pub fn config(mut self, config: JointLossConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the per-output loss functions.
    pub fn output_losses(mut self, losses: Vec<LossFunction>) -> Self {
        self.config.output_losses = losses;
        self
    }

    /// Sets the loss combination strategy.
    pub fn combination(mut self, combination: LossCombination) -> Self {
        self.config.combination = combination;
        self
    }

    /// Sets the L2 regularization strength.
    pub fn regularization(mut self, regularization: Float) -> Self {
        self.config.regularization = regularization;
        self
    }

    /// Sets the maximum number of iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.config.max_iter = max_iter;
        self
    }

    /// Sets the convergence tolerance.
    pub fn tol(mut self, tol: Float) -> Self {
        self.config.tol = tol;
        self
    }

    /// Sets the gradient-descent learning rate.
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.config.learning_rate = learning_rate;
        self
    }

    /// Sets the random seed (currently not applied; see `fit`).
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.config.random_state = random_state;
        self
    }
}

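// Builder-style construction equivalent to the struct-update example above
// (illustrative values):
//
//     let optimizer = JointLossOptimizer::new()
//         .output_losses(vec![LossFunction::MSE, LossFunction::MAE])
//         .combination(LossCombination::WeightedSum(vec![0.7, 0.3]))
//         .learning_rate(0.05)
//         .max_iter(500);
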
impl Default for JointLossOptimizer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for JointLossOptimizer<Untrained> {
    type Config = JointLossConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for JointLossOptimizer<Untrained> {
    type Fitted = JointLossOptimizer<JointLossOptimizerTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView2<'_, Float>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let (y_samples, n_outputs) = y.dim();

        if n_samples != y_samples {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        if n_outputs != self.config.output_losses.len() {
            return Err(SklearsError::InvalidInput(format!(
                "Number of outputs ({}) must match number of loss functions ({})",
                n_outputs,
                self.config.output_losses.len()
            )));
        }

        // NOTE: `config.random_state` is accepted but not yet wired into the
        // generator; initialization always uses the thread-local RNG.
        let mut rng = thread_rng();

        // Xavier/Glorot-style initialization: std = sqrt(2 / (fan_in + fan_out)).
        let std_dev = (2.0 / (n_features + n_outputs) as Float).sqrt();
        let normal_dist = RandNormal::new(0.0, std_dev).unwrap();
        let mut weights = Array2::<Float>::zeros((n_features, n_outputs));
        for i in 0..n_features {
            for j in 0..n_outputs {
                weights[[i, j]] = rng.sample(normal_dist);
            }
        }
        let mut bias = Array1::<Float>::zeros(n_outputs);

        let mut loss_history = Vec::new();
        let mut prev_loss = Float::INFINITY;

        for _iteration in 0..self.config.max_iter {
            // Forward pass: linear predictions with a broadcast bias row.
            let predictions = X.dot(&weights) + &bias;

            let joint_loss = self.compute_joint_loss(&predictions, y)?;
            loss_history.push(joint_loss);

            // Stop once the change in joint loss falls below the tolerance.
            if (prev_loss - joint_loss).abs() < self.config.tol {
                break;
            }
            prev_loss = joint_loss;

            let (weight_gradients, bias_gradients) = self.compute_gradients(X, y, &predictions)?;

            // Plain gradient-descent step.
            weights = weights - self.config.learning_rate * weight_gradients;
            bias = bias - self.config.learning_rate * bias_gradients;

            // L2 regularization applied as multiplicative weight decay.
            if self.config.regularization > 0.0 {
                weights *= 1.0 - self.config.regularization * self.config.learning_rate;
            }
        }

        Ok(JointLossOptimizer {
            state: JointLossOptimizerTrained {
                weights,
                bias,
                n_features,
                n_outputs,
                loss_history,
                config: self.config.clone(),
            },
            config: self.config,
        })
    }
}

impl JointLossOptimizer<Untrained> {
    /// Evaluates each output's loss and combines them according to the
    /// configured [`LossCombination`].
    fn compute_joint_loss(
        &self,
        predictions: &Array2<Float>,
        y: &ArrayView2<'_, Float>,
    ) -> SklResult<Float> {
        let mut individual_losses = Vec::new();

        for (i, loss_fn) in self.config.output_losses.iter().enumerate() {
            let pred_col = predictions.column(i);
            let y_col = y.column(i);
            let loss = self.compute_individual_loss(loss_fn, &pred_col, &y_col)?;
            individual_losses.push(loss);
        }

        let joint_loss = match &self.config.combination {
            LossCombination::Sum => individual_losses.iter().sum(),
            LossCombination::WeightedSum(weights) => {
                if weights.len() != individual_losses.len() {
                    return Err(SklearsError::InvalidInput(
                        "Weight vector length must match number of outputs".to_string(),
                    ));
                }
                individual_losses
                    .iter()
                    .zip(weights.iter())
                    .map(|(loss, weight)| loss * weight)
                    .sum()
            }
            // All supported losses are non-negative, so 0.0 is a safe
            // identity element for the fold.
            LossCombination::Max => individual_losses.iter().cloned().fold(0.0, Float::max),
            LossCombination::GeometricMean => {
                let product: Float = individual_losses.iter().product();
                product.powf(1.0 / individual_losses.len() as Float)
            }
            LossCombination::Adaptive => {
                // Weight each loss by its share of the total so that outputs
                // that are currently harder contribute proportionally more.
                let total_loss: Float = individual_losses.iter().sum();
                if total_loss > 0.0 {
                    let weights: Vec<Float> = individual_losses
                        .iter()
                        .map(|&loss| loss / total_loss)
                        .collect();
                    individual_losses
                        .iter()
                        .zip(weights.iter())
                        .map(|(loss, weight)| loss * weight)
                        .sum()
                } else {
                    0.0
                }
            }
        };

        Ok(joint_loss)
    }

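    // Worked example for the Adaptive combination (values illustrative): with
    // per-output losses [0.2, 0.8], the total is 1.0 and the adaptive weights
    // are [0.2, 0.8], giving a joint loss of 0.2 * 0.2 + 0.8 * 0.8 = 0.68;
    // the larger loss dominates the objective.
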
    /// Computes a single output's loss between predictions and targets.
    fn compute_individual_loss(
        &self,
        loss_fn: &LossFunction,
        predictions: &ArrayView1<'_, Float>,
        y: &ArrayView1<'_, Float>,
    ) -> SklResult<Float> {
        match loss_fn {
            LossFunction::MSE => {
                let diff = predictions - y;
                Ok(diff.mapv(|x| x * x).mean().unwrap_or(0.0))
            }
            LossFunction::MAE => {
                let diff = predictions - y;
                Ok(diff.mapv(|x| x.abs()).mean().unwrap_or(0.0))
            }
            // Quadratic within `delta` of the target, linear beyond it.
            LossFunction::Huber(delta) => {
                let diff = predictions - y;
                let huber_loss = diff.mapv(|x| {
                    if x.abs() <= *delta {
                        0.5 * x * x
                    } else {
                        delta * x.abs() - 0.5 * delta * delta
                    }
                });
                Ok(huber_loss.mean().unwrap_or(0.0))
            }
            // Binary cross-entropy; predictions are clipped away from 0 and 1
            // so the logarithms stay finite.
            LossFunction::CrossEntropy => {
                let epsilon = 1e-15;
                let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(1.0 - epsilon));
                let loss = y
                    .iter()
                    .zip(clipped_preds.iter())
                    .map(|(y_true, y_pred)| {
                        -(y_true * y_pred.ln() + (1.0 - y_true) * (1.0 - y_pred).ln())
                    })
                    .sum::<Float>()
                    / y.len() as Float;
                Ok(loss)
            }
            // Hinge loss: penalizes samples whose margin y * pred is below 1.
            LossFunction::Hinge => {
                let loss = predictions
                    .iter()
                    .zip(y.iter())
                    .map(|(pred, true_val)| {
                        let margin = true_val * pred;
                        if margin < 1.0 {
                            1.0 - margin
                        } else {
                            0.0
                        }
                    })
                    .sum::<Float>()
                    / y.len() as Float;
                Ok(loss)
            }
            LossFunction::Custom(_) => Err(SklearsError::InvalidInput(
                "Custom loss functions are not yet implemented".to_string(),
            )),
        }
    }

    /// Computes weight and bias gradients for every output.
    ///
    /// For the linear model, the weight gradient for output `i` is
    /// `X^T g_i / n_samples` and the bias gradient is `mean(g_i)`, where
    /// `g_i` is the per-sample gradient of output `i`'s loss with respect
    /// to its predictions.
    fn compute_gradients(
        &self,
        X: &ArrayView2<'_, Float>,
        y: &ArrayView2<'_, Float>,
        predictions: &Array2<Float>,
    ) -> SklResult<(Array2<Float>, Array1<Float>)> {
        let (n_samples, n_features) = X.dim();
        let n_outputs = y.ncols();

        let mut weight_gradients = Array2::<Float>::zeros((n_features, n_outputs));
        let mut bias_gradients = Array1::<Float>::zeros(n_outputs);

        for (i, loss_fn) in self.config.output_losses.iter().enumerate() {
            let pred_col = predictions.column(i);
            let y_col = y.column(i);

            let output_gradient = self.compute_output_gradient(loss_fn, &pred_col, &y_col)?;

            for j in 0..n_features {
                weight_gradients[(j, i)] = X.column(j).dot(&output_gradient) / n_samples as Float;
            }

            bias_gradients[i] = output_gradient.mean().unwrap_or(0.0);
        }

        Ok((weight_gradients, bias_gradients))
    }

    /// Per-sample gradient of a single output's loss with respect to its
    /// predictions.
    fn compute_output_gradient(
        &self,
        loss_fn: &LossFunction,
        predictions: &ArrayView1<'_, Float>,
        y: &ArrayView1<'_, Float>,
    ) -> SklResult<Array1<Float>> {
        let gradient = match loss_fn {
            // d/dp (p - y)^2 = 2 (p - y)
            LossFunction::MSE => 2.0 * (predictions - y),
            // d/dp |p - y| = sign(p - y); the subgradient at 0 is taken as 0.
            LossFunction::MAE => (predictions - y).mapv(|x| {
                if x > 0.0 {
                    1.0
                } else if x < 0.0 {
                    -1.0
                } else {
                    0.0
                }
            }),
            // Linear within `delta` of the target, clipped to +/- delta beyond.
            LossFunction::Huber(delta) => {
                let diff = predictions - y;
                diff.mapv(|x| {
                    if x.abs() <= *delta {
                        x
                    } else {
                        delta * x.signum()
                    }
                })
            }
            // The (p - y) form is the cross-entropy gradient with respect to
            // sigmoid logits; applying it to the clipped linear predictions
            // treats them as probabilities directly.
            LossFunction::CrossEntropy => {
                let epsilon = 1e-15;
                let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(1.0 - epsilon));
                &clipped_preds - y
            }
            // Gradient is -y where the margin y * p is violated, 0 elsewhere.
            LossFunction::Hinge => predictions
                .iter()
                .zip(y.iter())
                .map(|(pred, true_val)| {
                    let margin = true_val * pred;
                    if margin < 1.0 {
                        -true_val
                    } else {
                        0.0
                    }
                })
                .collect::<Array1<Float>>(),
            LossFunction::Custom(_) => {
                return Err(SklearsError::InvalidInput(
                    "Custom loss functions are not yet implemented".to_string(),
                ));
            }
        };

        Ok(gradient)
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<Float>>
    for JointLossOptimizer<JointLossOptimizerTrained>
{
    /// Predicts all outputs at once: `y_hat = X W + b`.
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (_n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "Expected {} features, got {}",
                self.state.n_features, n_features
            )));
        }

        let predictions = X.dot(&self.state.weights) + &self.state.bias;
        Ok(predictions)
    }
}

impl Estimator for JointLossOptimizer<JointLossOptimizerTrained> {
    type Config = JointLossConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.state.config
    }
}

impl JointLossOptimizer<JointLossOptimizerTrained> {
    /// Joint loss recorded at each training iteration.
    pub fn loss_history(&self) -> &[Float] {
        &self.state.loss_history
    }

    /// Learned weight matrix of shape `(n_features, n_outputs)`.
    pub fn weights(&self) -> &Array2<Float> {
        &self.state.weights
    }

    /// Learned bias vector of length `n_outputs`.
    pub fn bias(&self) -> &Array1<Float> {
        &self.state.bias
    }

    /// Number of features seen during fit.
    pub fn n_features(&self) -> usize {
        self.state.n_features
    }

    /// Number of outputs seen during fit.
    pub fn n_outputs(&self) -> usize {
        self.state.n_outputs
    }
}
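
// A minimal smoke test sketch, not part of the original module. It assumes the
// `Fit`/`Predict` traits behave as used above and that `Array2::from_shape_vec`
// is available through the scirs2_core ndarray re-export (row-major layout).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fit_and_predict_shapes() {
        // Tiny 4-sample, 2-feature, 2-output problem with arbitrary values.
        let x = Array2::from_shape_vec(
            (4, 2),
            vec![0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0],
        )
        .unwrap();
        let y = Array2::from_shape_vec(
            (4, 2),
            vec![1.0, 0.5, 0.0, 1.0, 1.0, 1.5, 0.0, 0.0],
        )
        .unwrap();

        let model = JointLossOptimizer::new()
            .output_losses(vec![LossFunction::MSE, LossFunction::MAE])
            .max_iter(50)
            .fit(&x.view(), &y.view())
            .expect("fit should succeed on well-formed input");

        // At least one iteration's joint loss should have been recorded.
        assert!(!model.loss_history().is_empty());

        // Predictions keep the (n_samples, n_outputs) shape.
        let preds = model.predict(&x.view()).expect("predict should succeed");
        assert_eq!(preds.dim(), (4, 2));
    }
}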