sklears_kernel_approximation/kernel_ridge_regression/robust_regression.rs

//! Robust Kernel Ridge Regression Implementation
//!
//! This module implements robust variants of kernel ridge regression that are resistant
//! to outliers and noise in the data. Multiple robust loss functions are supported,
//! and the optimization is performed using iteratively reweighted least squares (IRLS).
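//!
//! At each IRLS iteration, every sample receives a weight `w_i = ψ(r_i) / r_i`
//! derived from the derivative `ψ` of the chosen loss, and the coefficients are
//! updated by solving the regularized weighted normal equations (a sketch of the
//! scheme implemented in `solve_robust_regression` below):
//!
//! ```text
//! (Xᵀ W X + α I) β = Xᵀ W y,    W = diag(w_1, …, w_n)
//! ```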

use crate::{
    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
};
use scirs2_core::ndarray::ndarray_linalg::solve::Solve;
use scirs2_core::ndarray::ndarray_linalg::SVD;
use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::{Estimator, Fit, Float, Predict};
use std::marker::PhantomData;

use super::core_types::*;

/// Robust kernel ridge regression
///
/// This implements robust variants of kernel ridge regression that are resistant
/// to outliers and noise in the data. Multiple robust loss functions are supported.
///
/// # Parameters
///
/// * `approximation_method` - Method for kernel approximation
/// * `alpha` - Regularization strength
/// * `robust_loss` - Robust loss function to use
/// * `solver` - Method for solving the optimization problem
/// * `max_iter` - Maximum number of iterations for robust optimization
/// * `tolerance` - Convergence tolerance
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::kernel_ridge_regression::{
///     ApproximationMethod, RobustKernelRidgeRegression, RobustLoss,
/// };
///
/// let model = RobustKernelRidgeRegression::new(
///     ApproximationMethod::RandomFourierFeatures { n_components: 100, gamma: 0.1 },
/// )
/// .alpha(0.1)
/// .robust_loss(RobustLoss::Huber { delta: 1.0 });
/// ```
#[derive(Debug, Clone)]
pub struct RobustKernelRidgeRegression<State = Untrained> {
    pub approximation_method: ApproximationMethod,
    pub alpha: Float,
    pub robust_loss: RobustLoss,
    pub solver: Solver,
    pub max_iter: usize,
    pub tolerance: Float,
    pub random_state: Option<u64>,

    // Fitted parameters
    weights_: Option<Array1<Float>>,
    feature_transformer_: Option<FeatureTransformer>,
    sample_weights_: Option<Array1<Float>>, // Adaptive weights for robustness

    _state: PhantomData<State>,
}

/// Robust loss functions for kernel ridge regression
#[derive(Debug, Clone)]
pub enum RobustLoss {
    /// Huber loss - quadratic for small residuals, linear for large ones
    Huber { delta: Float },
    /// Epsilon-insensitive loss (used in SVR)
    EpsilonInsensitive { epsilon: Float },
    /// Quantile loss for quantile regression
    Quantile { tau: Float },
    /// Tukey's biweight loss
    Tukey { c: Float },
    /// Cauchy loss
    Cauchy { sigma: Float },
    /// Logistic loss
    Logistic { scale: Float },
    /// Fair loss
    Fair { c: Float },
    /// Welsch loss
    Welsch { c: Float },
    /// Custom robust loss function
    Custom {
        loss_fn: fn(Float) -> Float,
        weight_fn: fn(Float) -> Float,
    },
}

impl Default for RobustLoss {
    fn default() -> Self {
        Self::Huber { delta: 1.0 }
    }
}

impl RobustLoss {
    /// Compute the loss for a given residual
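    ///
    /// For example, Huber with `delta = 1.0` is quadratic inside the band and
    /// linear outside it:
    ///
    /// ```rust,ignore
    /// let huber = RobustLoss::Huber { delta: 1.0 };
    /// assert_eq!(huber.loss(0.5), 0.125); // 0.5 * 0.5^2
    /// assert_eq!(huber.loss(2.0), 1.5);   // 1.0 * (2.0 - 0.5 * 1.0)
    /// ```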
    pub fn loss(&self, residual: Float) -> Float {
        let abs_r = residual.abs();
        match self {
            RobustLoss::Huber { delta } => {
                if abs_r <= *delta {
                    0.5 * residual * residual
                } else {
                    delta * (abs_r - 0.5 * delta)
                }
            }
            RobustLoss::EpsilonInsensitive { epsilon } => (abs_r - epsilon).max(0.0),
            RobustLoss::Quantile { tau } => {
                if residual >= 0.0 {
                    tau * residual
                } else {
                    (tau - 1.0) * residual
                }
            }
            RobustLoss::Tukey { c } => {
                if abs_r <= *c {
                    let r_norm = residual / c;
                    (c * c / 6.0) * (1.0 - (1.0 - r_norm * r_norm).powi(3))
                } else {
                    c * c / 6.0
                }
            }
            RobustLoss::Cauchy { sigma } => {
                (sigma * sigma / 2.0) * ((1.0 + (residual / sigma).powi(2)).ln())
            }
            RobustLoss::Logistic { scale } => {
                // ln-cosh logistic loss: rho(r) = s^2 * ln(cosh(r/s)), which increases
                // monotonically with |r|. Evaluated in a numerically stable form:
                // ln(cosh(x)) = |x| + ln(1 + e^{-2|x|}) - ln 2
                let x = abs_r / scale;
                scale * scale * (x + (-2.0 * x).exp().ln_1p() - (2.0 as Float).ln())
            }
            RobustLoss::Fair { c } => c * (abs_r / c - (1.0 + abs_r / c).ln()),
            RobustLoss::Welsch { c } => (c * c / 2.0) * (1.0 - (-((residual / c).powi(2))).exp()),
            RobustLoss::Custom { loss_fn, .. } => loss_fn(residual),
        }
    }

    /// Compute the weight for iteratively reweighted least squares
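    ///
    /// The weight is `w(r) = ψ(r) / r`, where `ψ` is the derivative of the loss;
    /// for Huber with `delta = 1.0`:
    ///
    /// ```rust,ignore
    /// let huber = RobustLoss::Huber { delta: 1.0 };
    /// assert_eq!(huber.weight(0.5), 1.0); // inside the quadratic band
    /// assert_eq!(huber.weight(2.0), 0.5); // delta / |r|
    /// ```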
    pub fn weight(&self, residual: Float) -> Float {
        let abs_r = residual.abs();
        if abs_r < 1e-10 {
            return 1.0; // Avoid division by zero
        }

        match self {
            RobustLoss::Huber { delta } => {
                if abs_r <= *delta {
                    1.0
                } else {
                    delta / abs_r
                }
            }
            RobustLoss::EpsilonInsensitive { epsilon } => {
                if abs_r <= *epsilon {
                    0.0
                } else {
                    1.0
                }
            }
            RobustLoss::Quantile { tau } => {
                // Piecewise-constant weights by residual sign (an asymmetric
                // least-squares approximation of the check loss)
                if residual >= 0.0 {
                    *tau
                } else {
                    1.0 - tau
                }
            }
            RobustLoss::Tukey { c } => {
                if abs_r <= *c {
                    let r_norm = residual / c;
                    (1.0 - r_norm * r_norm).powi(2)
                } else {
                    0.0
                }
            }
            RobustLoss::Cauchy { sigma } => 1.0 / (1.0 + (residual / sigma).powi(2)),
            RobustLoss::Logistic { scale } => {
                // IRLS weight for the ln-cosh logistic loss: w(r) = tanh(r/s) / (r/s),
                // which tends to 1 as r -> 0 (consistent with the guard above)
                let x = abs_r / scale;
                x.tanh() / x
            }
            RobustLoss::Fair { c } => 1.0 / (1.0 + abs_r / c),
            RobustLoss::Welsch { c } => (-((residual / c).powi(2))).exp(),
            RobustLoss::Custom { weight_fn, .. } => weight_fn(residual),
        }
    }
}

impl RobustKernelRidgeRegression<Untrained> {
    /// Create a new robust kernel ridge regression model
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            robust_loss: RobustLoss::default(),
            solver: Solver::Direct,
            max_iter: 100,
            tolerance: 1e-6,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            sample_weights_: None,
            _state: PhantomData,
        }
    }

    /// Set regularization parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set robust loss function
    pub fn robust_loss(mut self, robust_loss: RobustLoss) -> Self {
        self.robust_loss = robust_loss;
        self
    }

    /// Set solver method
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

    /// Set maximum iterations for robust optimization
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for RobustKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Untrained> {
    type Fitted = RobustKernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        if x.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(format!(
                "Number of samples in x ({}) must match length of y ({})",
                x.nrows(),
                y.len()
            )));
        }

        // Fit the feature transformer
        let feature_transformer = self.fit_feature_transformer(x)?;
        let x_transformed = feature_transformer.transform(x)?;

        // Solve robust regression using iteratively reweighted least squares (IRLS)
        let (weights, sample_weights) = self.solve_robust_regression(&x_transformed, y)?;

        Ok(RobustKernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            robust_loss: self.robust_loss,
            solver: self.solver,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            sample_weights_: Some(sample_weights),
            _state: PhantomData,
        })
    }
}

impl RobustKernelRidgeRegression<Untrained> {
    /// Fit the feature transformer based on the approximation method
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());
                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }
                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    rff = rff.random_state(seed);
                }
                let fitted = rff.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    srf = srf.random_state(seed);
                }
                let fitted = srf.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }
                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Solve robust regression using iteratively reweighted least squares
    fn solve_robust_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(Array1<Float>, Array1<Float>)> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Initialize with the ridge (L2-regularized least squares) solution
        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]] as f64);
        let y_f64 = Array1::from_vec(y.iter().map(|&val| val as f64).collect());

        let xtx = x_f64.t().dot(&x_f64);
        let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * (self.alpha as f64);
        let xty = x_f64.t().dot(&y_f64);
        let mut weights_f64 =
            regularized_xtx
                .solve(&xty)
                .map_err(|e| SklearsError::InvalidParameter {
                    name: "regularization".to_string(),
                    reason: format!("Initial linear system solving failed: {:?}", e),
                })?;

        let mut sample_weights = Array1::ones(n_samples);
        let mut prev_weights = weights_f64.clone();

        // Iteratively reweighted least squares
        for _iter in 0..self.max_iter {
            // Compute residuals
            let predictions = x_f64.dot(&weights_f64);
            let residuals = &y_f64 - &predictions;

            // Update sample weights based on residuals
            for (i, &residual) in residuals.iter().enumerate() {
                sample_weights[i] = self.robust_loss.weight(residual as Float) as f64;
            }

            // Solve weighted least squares
            let mut weighted_xtx = Array2::zeros((n_features, n_features));
            let mut weighted_xty = Array1::zeros(n_features);

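            // Accumulate XᵀWX and XᵀWy row by row, where W = diag(sample_weights);
            // this forms the weighted normal equations for the current weights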
            for i in 0..n_samples {
                let weight = sample_weights[i];
                let x_row = x_f64.row(i);

                // X^T W X
                for j in 0..n_features {
                    for k in 0..n_features {
                        weighted_xtx[[j, k]] += weight * x_row[j] * x_row[k];
                    }
                }

                // X^T W y
                for j in 0..n_features {
                    weighted_xty[j] += weight * x_row[j] * y_f64[i];
                }
            }

            // Add regularization
            weighted_xtx += &(Array2::eye(n_features) * (self.alpha as f64));

            // Solve the weighted system
            weights_f64 = match self.solver {
                Solver::Direct => weighted_xtx.solve(&weighted_xty).map_err(|e| {
                    SklearsError::InvalidParameter {
                        name: "weighted_system".to_string(),
                        reason: format!("Weighted linear system solving failed: {:?}", e),
                    }
                })?,
                Solver::SVD => {
                    let (u, s, vt) = weighted_xtx.svd(true, true).map_err(|e| {
                        SklearsError::InvalidParameter {
                            name: "svd".to_string(),
                            reason: format!("SVD decomposition failed: {:?}", e),
                        }
                    })?;
                    let u = u.unwrap();
                    let vt = vt.unwrap();
                    let ut_b = u.t().dot(&weighted_xty);
                    let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
                    let y_svd = ut_b * s_inv;
                    vt.t().dot(&y_svd)
                }
                Solver::ConjugateGradient { max_iter, tol } => {
                    self.solve_cg_weighted(&weighted_xtx, &weighted_xty, max_iter, tol as f64)?
                }
            };

            // Check convergence
            let weight_change = (&weights_f64 - &prev_weights).mapv(|x| x.abs()).sum();
            if weight_change < (self.tolerance as f64) {
                break;
            }

            prev_weights = weights_f64.clone();
        }

        // Convert back to Float
        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
        let sample_weights_float =
            Array1::from_vec(sample_weights.iter().map(|&val| val as Float).collect());

        Ok((weights, sample_weights_float))
    }

    /// Solve using conjugate gradient method
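    ///
    /// Standard conjugate gradient iteration for a symmetric positive-definite
    /// system `A x = b`; the ridge term `alpha * I` added to the normal equations
    /// above keeps the system positive definite.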
    fn solve_cg_weighted(
        &self,
        a: &Array2<f64>,
        b: &Array1<f64>,
        max_iter: usize,
        tol: f64,
    ) -> Result<Array1<f64>> {
        let n = b.len();
        let mut x = Array1::zeros(n);
        let mut r = b - &a.dot(&x);
        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for _iter in 0..max_iter {
            let ap = a.dot(&p);
            let alpha = rsold / p.dot(&ap);

            x = x + &p * alpha;
            r = r - &ap * alpha;

            let rsnew = r.dot(&r);

            if rsnew.sqrt() < tol {
                break;
            }

            let beta = rsnew / rsold;
            p = &r + &p * beta;
            rsold = rsnew;
        }

        Ok(x)
    }
}

impl Predict<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let feature_transformer =
            self.feature_transformer_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "predict".to_string(),
                })?;

        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let x_transformed = feature_transformer.transform(x)?;
        let predictions = x_transformed.dot(weights);

        Ok(predictions)
    }
}

impl RobustKernelRidgeRegression<Trained> {
    /// Get the fitted weights
    pub fn weights(&self) -> Option<&Array1<Float>> {
        self.weights_.as_ref()
    }

    /// Get the sample weights from robust fitting
    pub fn sample_weights(&self) -> Option<&Array1<Float>> {
        self.sample_weights_.as_ref()
    }

    /// Compute robust residuals and their weights
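    ///
    /// ```rust,ignore
    /// let (residuals, weights) = fitted.robust_residuals(&x, &y)?;
    /// // Large |residual| entries pair with small weights under a robust loss
    /// ```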
    pub fn robust_residuals(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(Array1<Float>, Array1<Float>)> {
        let predictions = self.predict(x)?;
        let residuals = y - &predictions;

        let mut weights = Array1::zeros(residuals.len());
        for (i, &residual) in residuals.iter().enumerate() {
            weights[i] = self.robust_loss.weight(residual);
        }

        Ok((residuals, weights))
    }

    /// Get outlier scores (lower weight means more likely to be outlier)
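    ///
    /// ```rust,ignore
    /// let scores = fitted.outlier_scores().unwrap();
    /// // Scores near 1.0 mark samples that IRLS down-weighted heavily,
    /// // i.e. the most likely outliers
    /// ```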
    pub fn outlier_scores(&self) -> Option<Array1<Float>> {
        self.sample_weights_.as_ref().map(|weights| {
            // Convert weights to outlier scores (1 - weight)
            weights.mapv(|w| 1.0 - w)
        })
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_robust_kernel_ridge_regression() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [10.0, 10.0]]; // Last point is an outlier
        let y = array![1.0, 2.0, 3.0, 100.0]; // Last target is an outlier

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = robust_krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);

        // Check that predictions are reasonable
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }

        // Check that the outlier received a lower weight
        let sample_weights = fitted.sample_weights().unwrap();
        assert!(sample_weights[3] < sample_weights[0]);
    }

    #[test]
    fn test_different_robust_losses() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let loss_functions = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
        ];

        for loss in loss_functions {
            let robust_krr = RobustKernelRidgeRegression::new(approximation.clone())
                .alpha(0.1)
                .robust_loss(loss);

            let fitted = robust_krr.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();

            assert_eq!(predictions.len(), 3);
        }
    }

    #[test]
    fn test_robust_loss_functions() {
        let huber = RobustLoss::Huber { delta: 1.0 };

        // Test loss computation
        assert_eq!(huber.loss(0.5), 0.125); // Quadratic region
        assert_eq!(huber.loss(2.0), 1.5); // Linear region

        // Test weight computation
        assert_eq!(huber.weight(0.5), 1.0); // Quadratic region
        assert_eq!(huber.weight(2.0), 0.5); // Linear region
    }

    #[test]
    fn test_robust_outlier_detection() {
        let x = array![[1.0], [2.0], [3.0], [100.0]]; // Last point is an outlier
        let y = array![1.0, 2.0, 3.0, 100.0]; // Last target is an outlier

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = robust_krr.fit(&x, &y).unwrap();
        let outlier_scores = fitted.outlier_scores().unwrap();

        // The outlier should have the highest score
        assert!(outlier_scores[3] > outlier_scores[0]);
        assert!(outlier_scores[3] > outlier_scores[1]);
        assert!(outlier_scores[3] > outlier_scores[2]);
    }

    #[test]
    fn test_robust_convergence() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .max_iter(5) // Small number of iterations
            .tolerance(1e-3);

        let fitted = robust_krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
        // Should converge even with few iterations for this simple case
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    #[test]
    fn test_robust_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr1 = RobustKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42);
        let fitted1 = robust_krr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let robust_krr2 = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42);
        let fitted2 = robust_krr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }

    #[test]
    fn test_robust_loss_edge_cases() {
        let losses = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
            RobustLoss::Logistic { scale: 1.0 },
            RobustLoss::Fair { c: 1.0 },
            RobustLoss::Welsch { c: 1.0 },
        ];

        for loss in losses {
            // Test with zero residual
            let loss_zero = loss.loss(0.0);
            let weight_zero = loss.weight(0.0);

            assert!(loss_zero >= 0.0);
            assert!(weight_zero >= 0.0);
            assert!(weight_zero <= 1.5); // Most weights are <= 1.0; allow some slack

            // Test with a non-zero residual
            let loss_nonzero = loss.loss(1.0);
            let weight_nonzero = loss.weight(1.0);

            assert!(loss_nonzero >= 0.0);
            assert!(weight_nonzero >= 0.0);
        }
    }

    #[test]
    fn test_custom_robust_loss() {
        let custom_loss = RobustLoss::Custom {
            loss_fn: |r| r * r, // Simple quadratic loss
            weight_fn: |_| 1.0, // Constant weight
        };

        assert_eq!(custom_loss.loss(2.0), 4.0);
        assert_eq!(custom_loss.weight(5.0), 1.0);
    }
}
727}