// sklears_kernel_approximation/kernel_ridge_regression/robust_regression.rs
//! Robust Kernel Ridge Regression Implementation
//!
//! This module implements robust variants of kernel ridge regression that are resistant
//! to outliers and noise in the data. Multiple robust loss functions are supported,
//! and the optimization is performed using iteratively reweighted least squares (IRLS).

7use crate::{
8    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
9};
10use scirs2_linalg::compat::ArrayLinalgExt;
11// Removed SVD import - using ArrayLinalgExt for both solve and svd methods
12
13use scirs2_core::ndarray::{Array1, Array2};
14use sklears_core::error::{Result, SklearsError};
15use sklears_core::prelude::{Estimator, Fit, Float, Predict};
16use std::marker::PhantomData;
17
18use super::core_types::*;
19
/// Robust kernel ridge regression
///
/// This implements robust variants of kernel ridge regression that are resistant
/// to outliers and noise in the data. Multiple robust loss functions are supported.
///
/// # Parameters
///
/// * `approximation_method` - Method for kernel approximation
/// * `alpha` - Regularization strength
/// * `robust_loss` - Robust loss function to use
/// * `solver` - Method for solving the optimization problem
/// * `max_iter` - Maximum number of iterations for robust optimization
/// * `tolerance` - Convergence tolerance
/// * `random_state` - Random seed for reproducibility
///
/// The `State` type parameter implements the type-state pattern: a model starts
/// as `Untrained`, and `fit` returns a `Trained` instance on which the fitted
/// fields below are populated and prediction is available.
#[derive(Debug, Clone)]
pub struct RobustKernelRidgeRegression<State = Untrained> {
    /// Kernel approximation strategy (Nystroem, random Fourier features, etc.).
    pub approximation_method: ApproximationMethod,
    /// Ridge regularization strength.
    pub alpha: Float,
    /// Robust loss used to reweight samples during IRLS.
    pub robust_loss: RobustLoss,
    /// Linear-system solver for the weighted normal equations.
    pub solver: Solver,
    /// Maximum number of IRLS iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the coefficient change between IRLS iterations.
    pub tolerance: Float,
    /// Optional random seed forwarded to the feature transformer.
    pub random_state: Option<u64>,

    // Fitted parameters (populated only after a successful fit)
    weights_: Option<Array1<Float>>,
    feature_transformer_: Option<FeatureTransformer>,
    sample_weights_: Option<Array1<Float>>, // Adaptive weights for robustness

    _state: PhantomData<State>,
}

/// Robust loss functions for kernel ridge regression
///
/// Each variant provides both a loss value and an IRLS reweighting function
/// (see `RobustLoss::loss` and `RobustLoss::weight`). The scale parameters
/// (`delta`, `epsilon`, `tau`, `c`, `sigma`, `scale`) control where the loss
/// transitions from quadratic-like behavior to outlier resistance.
#[derive(Debug, Clone)]
pub enum RobustLoss {
    /// Huber loss - quadratic for small residuals, linear for large ones
    Huber { delta: Float },
    /// Epsilon-insensitive loss (used in SVR)
    EpsilonInsensitive { epsilon: Float },
    /// Quantile loss for quantile regression
    Quantile { tau: Float },
    /// Tukey's biweight loss
    Tukey { c: Float },
    /// Cauchy loss
    Cauchy { sigma: Float },
    /// Logistic loss
    Logistic { scale: Float },
    /// Fair loss
    Fair { c: Float },
    /// Welsch loss
    Welsch { c: Float },
    /// Custom robust loss function
    Custom {
        /// Loss value as a function of the residual.
        loss_fn: fn(Float) -> Float,
        /// IRLS weight as a function of the residual.
        weight_fn: fn(Float) -> Float,
    },
}

78impl Default for RobustLoss {
79    fn default() -> Self {
80        Self::Huber { delta: 1.0 }
81    }
82}
83
84impl RobustLoss {
85    /// Compute the loss for a given residual
86    pub fn loss(&self, residual: Float) -> Float {
87        let abs_r = residual.abs();
88        match self {
89            RobustLoss::Huber { delta } => {
90                if abs_r <= *delta {
91                    0.5 * residual * residual
92                } else {
93                    delta * (abs_r - 0.5 * delta)
94                }
95            }
96            RobustLoss::EpsilonInsensitive { epsilon } => (abs_r - epsilon).max(0.0),
97            RobustLoss::Quantile { tau } => {
98                if residual >= 0.0 {
99                    tau * residual
100                } else {
101                    (tau - 1.0) * residual
102                }
103            }
104            RobustLoss::Tukey { c } => {
105                if abs_r <= *c {
106                    let r_norm = residual / c;
107                    (c * c / 6.0) * (1.0 - (1.0 - r_norm * r_norm).powi(3))
108                } else {
109                    c * c / 6.0
110                }
111            }
112            RobustLoss::Cauchy { sigma } => {
113                (sigma * sigma / 2.0) * ((1.0 + (residual / sigma).powi(2)).ln())
114            }
115            RobustLoss::Logistic { scale } => scale * (1.0 + (-abs_r / scale).exp()).ln(),
116            RobustLoss::Fair { c } => c * (abs_r / c - (1.0 + abs_r / c).ln()),
117            RobustLoss::Welsch { c } => (c * c / 2.0) * (1.0 - (-((residual / c).powi(2))).exp()),
118            RobustLoss::Custom { loss_fn, .. } => loss_fn(residual),
119        }
120    }
121
122    /// Compute the weight for iteratively reweighted least squares
123    pub fn weight(&self, residual: Float) -> Float {
124        let abs_r = residual.abs();
125        if abs_r < 1e-10 {
126            return 1.0; // Avoid division by zero
127        }
128
129        match self {
130            RobustLoss::Huber { delta } => {
131                if abs_r <= *delta {
132                    1.0
133                } else {
134                    delta / abs_r
135                }
136            }
137            RobustLoss::EpsilonInsensitive { epsilon } => {
138                if abs_r <= *epsilon {
139                    0.0
140                } else {
141                    1.0
142                }
143            }
144            RobustLoss::Quantile { tau } => {
145                // For quantile regression, weights are constant
146                if residual >= 0.0 {
147                    *tau
148                } else {
149                    1.0 - tau
150                }
151            }
152            RobustLoss::Tukey { c } => {
153                if abs_r <= *c {
154                    let r_norm = residual / c;
155                    (1.0 - r_norm * r_norm).powi(2)
156                } else {
157                    0.0
158                }
159            }
160            RobustLoss::Cauchy { sigma } => 1.0 / (1.0 + (residual / sigma).powi(2)),
161            RobustLoss::Logistic { scale } => {
162                let exp_term = (-abs_r / scale).exp();
163                exp_term / (1.0 + exp_term)
164            }
165            RobustLoss::Fair { c } => 1.0 / (1.0 + abs_r / c),
166            RobustLoss::Welsch { c } => (-((residual / c).powi(2))).exp(),
167            RobustLoss::Custom { weight_fn, .. } => weight_fn(residual),
168        }
169    }
170}
171
172impl RobustKernelRidgeRegression<Untrained> {
173    /// Create a new robust kernel ridge regression model
174    pub fn new(approximation_method: ApproximationMethod) -> Self {
175        Self {
176            approximation_method,
177            alpha: 1.0,
178            robust_loss: RobustLoss::default(),
179            solver: Solver::Direct,
180            max_iter: 100,
181            tolerance: 1e-6,
182            random_state: None,
183            weights_: None,
184            feature_transformer_: None,
185            sample_weights_: None,
186            _state: PhantomData,
187        }
188    }
189
190    /// Set regularization parameter
191    pub fn alpha(mut self, alpha: Float) -> Self {
192        self.alpha = alpha;
193        self
194    }
195
196    /// Set robust loss function
197    pub fn robust_loss(mut self, robust_loss: RobustLoss) -> Self {
198        self.robust_loss = robust_loss;
199        self
200    }
201
202    /// Set solver method
203    pub fn solver(mut self, solver: Solver) -> Self {
204        self.solver = solver;
205        self
206    }
207
208    /// Set maximum iterations for robust optimization
209    pub fn max_iter(mut self, max_iter: usize) -> Self {
210        self.max_iter = max_iter;
211        self
212    }
213
214    /// Set convergence tolerance
215    pub fn tolerance(mut self, tolerance: Float) -> Self {
216        self.tolerance = tolerance;
217        self
218    }
219
220    /// Set random state for reproducibility
221    pub fn random_state(mut self, seed: u64) -> Self {
222        self.random_state = Some(seed);
223        self
224    }
225}
226
impl Estimator for RobustKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator carries no separate configuration object — all
    /// hyperparameters live on the struct itself — so the config is the unit
    /// type and a `&'static ()` reference is returned.
    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Untrained> {
    type Fitted = RobustKernelRidgeRegression<Trained>;

    /// Fit the model: validate shapes, fit the kernel-approximation feature
    /// transformer, then run IRLS on the transformed features.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidInput` when `x.nrows() != y.len()`, and
    /// propagates failures from feature fitting or the linear solvers.
    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        if x.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(
                "Number of samples must match".to_string(),
            ));
        }

        // Fit the feature transformer and map x into the approximate kernel
        // feature space.
        let feature_transformer = self.fit_feature_transformer(x)?;
        let x_transformed = feature_transformer.transform(x)?;

        // Solve robust regression using iteratively reweighted least squares
        // (IRLS); also yields the final per-sample weights for diagnostics.
        let (weights, sample_weights) = self.solve_robust_regression(&x_transformed, y)?;

        // Rebuild the struct in the Trained type-state with fitted fields set.
        // (Field-by-field because the Untrained and Trained types differ.)
        Ok(RobustKernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            robust_loss: self.robust_loss,
            solver: self.solver,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            sample_weights_: Some(sample_weights),
            _state: PhantomData,
        })
    }
}

270impl RobustKernelRidgeRegression<Untrained> {
271    /// Fit the feature transformer based on the approximation method
272    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
273        match &self.approximation_method {
274            ApproximationMethod::Nystroem {
275                kernel,
276                n_components,
277                sampling_strategy,
278            } => {
279                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
280                    .sampling_strategy(sampling_strategy.clone());
281                if let Some(seed) = self.random_state {
282                    nystroem = nystroem.random_state(seed);
283                }
284                let fitted = nystroem.fit(x, &())?;
285                Ok(FeatureTransformer::Nystroem(fitted))
286            }
287            ApproximationMethod::RandomFourierFeatures {
288                n_components,
289                gamma,
290            } => {
291                let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
292                if let Some(seed) = self.random_state {
293                    rff = rff.random_state(seed);
294                }
295                let fitted = rff.fit(x, &())?;
296                Ok(FeatureTransformer::RBFSampler(fitted))
297            }
298            ApproximationMethod::StructuredRandomFeatures {
299                n_components,
300                gamma,
301            } => {
302                let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
303                if let Some(seed) = self.random_state {
304                    srf = srf.random_state(seed);
305                }
306                let fitted = srf.fit(x, &())?;
307                Ok(FeatureTransformer::StructuredRFF(fitted))
308            }
309            ApproximationMethod::Fastfood {
310                n_components,
311                gamma,
312            } => {
313                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
314                if let Some(seed) = self.random_state {
315                    fastfood = fastfood.random_state(seed);
316                }
317                let fitted = fastfood.fit(x, &())?;
318                Ok(FeatureTransformer::Fastfood(fitted))
319            }
320        }
321    }
322
323    /// Solve robust regression using iteratively reweighted least squares
324    fn solve_robust_regression(
325        &self,
326        x: &Array2<Float>,
327        y: &Array1<Float>,
328    ) -> Result<(Array1<Float>, Array1<Float>)> {
329        let n_samples = x.nrows();
330        let n_features = x.ncols();
331
332        // Initialize with ordinary least squares solution
333        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
334        let y_f64 = Array1::from_vec(y.iter().copied().collect());
335
336        let xtx = x_f64.t().dot(&x_f64);
337        let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;
338        let xty = x_f64.t().dot(&y_f64);
339        let mut weights_f64 =
340            regularized_xtx
341                .solve(&xty)
342                .map_err(|e| SklearsError::InvalidParameter {
343                    name: "regularization".to_string(),
344                    reason: format!("Initial linear system solving failed: {:?}", e),
345                })?;
346
347        let mut sample_weights = Array1::ones(n_samples);
348        let mut prev_weights = weights_f64.clone();
349
350        // Iteratively reweighted least squares
351        for _iter in 0..self.max_iter {
352            // Compute residuals
353            let predictions = x_f64.dot(&weights_f64);
354            let residuals = &y_f64 - &predictions;
355
356            // Update sample weights based on residuals
357            for (i, &residual) in residuals.iter().enumerate() {
358                sample_weights[i] = self.robust_loss.weight(residual as Float);
359            }
360
361            // Solve weighted least squares
362            let mut weighted_xtx = Array2::zeros((n_features, n_features));
363            let mut weighted_xty = Array1::zeros(n_features);
364
365            for i in 0..n_samples {
366                let weight = sample_weights[i];
367                let x_row = x_f64.row(i);
368
369                // X^T W X
370                for j in 0..n_features {
371                    for k in 0..n_features {
372                        weighted_xtx[[j, k]] += weight * x_row[j] * x_row[k];
373                    }
374                }
375
376                // X^T W y
377                for j in 0..n_features {
378                    weighted_xty[j] += weight * x_row[j] * y_f64[i];
379                }
380            }
381
382            // Add regularization
383            weighted_xtx += &(Array2::eye(n_features) * self.alpha);
384
385            // Solve the weighted system
386            weights_f64 = match self.solver {
387                Solver::Direct => weighted_xtx.solve(&weighted_xty).map_err(|e| {
388                    SklearsError::InvalidParameter {
389                        name: "weighted_system".to_string(),
390                        reason: format!("Weighted linear system solving failed: {:?}", e),
391                    }
392                })?,
393                Solver::SVD => {
394                    let (u, s, vt) =
395                        weighted_xtx
396                            .svd(true)
397                            .map_err(|e| SklearsError::InvalidParameter {
398                                name: "svd".to_string(),
399                                reason: format!("SVD decomposition failed: {:?}", e),
400                            })?;
401                    let ut_b = u.t().dot(&weighted_xty);
402                    let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
403                    let y_svd = ut_b * s_inv;
404                    vt.t().dot(&y_svd)
405                }
406                Solver::ConjugateGradient { max_iter, tol } => {
407                    self.solve_cg_weighted(&weighted_xtx, &weighted_xty, max_iter, tol)?
408                }
409            };
410
411            // Check convergence
412            let weight_change = (&weights_f64 - &prev_weights).mapv(|x| x.abs()).sum();
413            if weight_change < self.tolerance {
414                break;
415            }
416
417            prev_weights = weights_f64.clone();
418        }
419
420        // Convert back to Float
421        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
422        let sample_weights_float =
423            Array1::from_vec(sample_weights.iter().map(|&val| val as Float).collect());
424
425        Ok((weights, sample_weights_float))
426    }
427
428    /// Solve using conjugate gradient method
429    fn solve_cg_weighted(
430        &self,
431        a: &Array2<f64>,
432        b: &Array1<f64>,
433        max_iter: usize,
434        tol: f64,
435    ) -> Result<Array1<f64>> {
436        let n = b.len();
437        let mut x = Array1::zeros(n);
438        let mut r = b - &a.dot(&x);
439        let mut p = r.clone();
440        let mut rsold = r.dot(&r);
441
442        for _iter in 0..max_iter {
443            let ap = a.dot(&p);
444            let alpha = rsold / p.dot(&ap);
445
446            x = x + &p * alpha;
447            r = r - &ap * alpha;
448
449            let rsnew = r.dot(&r);
450
451            if rsnew.sqrt() < tol {
452                break;
453            }
454
455            let beta = rsnew / rsold;
456            p = &r + &p * beta;
457            rsold = rsnew;
458        }
459
460        Ok(x)
461    }
462}
463
464impl Predict<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Trained> {
465    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
466        let feature_transformer =
467            self.feature_transformer_
468                .as_ref()
469                .ok_or_else(|| SklearsError::NotFitted {
470                    operation: "predict".to_string(),
471                })?;
472
473        let weights = self
474            .weights_
475            .as_ref()
476            .ok_or_else(|| SklearsError::NotFitted {
477                operation: "predict".to_string(),
478            })?;
479
480        let x_transformed = feature_transformer.transform(x)?;
481        let predictions = x_transformed.dot(weights);
482
483        Ok(predictions)
484    }
485}
486
487impl RobustKernelRidgeRegression<Trained> {
488    /// Get the fitted weights
489    pub fn weights(&self) -> Option<&Array1<Float>> {
490        self.weights_.as_ref()
491    }
492
493    /// Get the sample weights from robust fitting
494    pub fn sample_weights(&self) -> Option<&Array1<Float>> {
495        self.sample_weights_.as_ref()
496    }
497
498    /// Compute robust residuals and their weights
499    pub fn robust_residuals(
500        &self,
501        x: &Array2<Float>,
502        y: &Array1<Float>,
503    ) -> Result<(Array1<Float>, Array1<Float>)> {
504        let predictions = self.predict(x)?;
505        let residuals = y - &predictions;
506
507        let mut weights = Array1::zeros(residuals.len());
508        for (i, &residual) in residuals.iter().enumerate() {
509            weights[i] = self.robust_loss.weight(residual);
510        }
511
512        Ok((residuals, weights))
513    }
514
515    /// Get outlier scores (lower weight means more likely to be outlier)
516    pub fn outlier_scores(&self) -> Option<Array1<Float>> {
517        self.sample_weights_.as_ref().map(|weights| {
518            // Convert weights to outlier scores (1 - weight)
519            weights.mapv(|w| 1.0 - w)
520        })
521    }
522}
523
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// End-to-end fit/predict with an obvious outlier: the robust fit should
    /// assign the contaminated sample a lower IRLS weight than clean samples.
    #[test]
    fn test_robust_kernel_ridge_regression() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [10.0, 10.0]]; // Last point is outlier
        let y = array![1.0, 2.0, 3.0, 100.0]; // Last target is outlier

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = robust_krr.fit(&x, &y).expect("operation should succeed");
        let predictions = fitted.predict(&x).expect("operation should succeed");

        assert_eq!(predictions.len(), 4);

        // Check that predictions are reasonable
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }

        // Check that outlier has lower weight
        let sample_weights = fitted.sample_weights().expect("operation should succeed");
        assert!(sample_weights[3] < sample_weights[0]); // Outlier should have lower weight
    }

    /// Smoke test: each supported built-in loss fits and predicts without error.
    #[test]
    fn test_different_robust_losses() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let loss_functions = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
        ];

        for loss in loss_functions {
            let robust_krr = RobustKernelRidgeRegression::new(approximation.clone())
                .alpha(0.1)
                .robust_loss(loss);

            let fitted = robust_krr.fit(&x, &y).expect("operation should succeed");
            let predictions = fitted.predict(&x).expect("operation should succeed");

            assert_eq!(predictions.len(), 3);
        }
    }

    /// Pins exact Huber loss/weight values in both the quadratic and linear
    /// regimes (delta = 1.0).
    #[test]
    fn test_robust_loss_functions() {
        let huber = RobustLoss::Huber { delta: 1.0 };

        // Test loss computation
        assert_eq!(huber.loss(0.5), 0.125); // Quadratic region
        assert_eq!(huber.loss(2.0), 1.5); // Linear region

        // Test weight computation
        assert_eq!(huber.weight(0.5), 1.0); // Quadratic region
        assert_eq!(huber.weight(2.0), 0.5); // Linear region
    }

    /// Outlier scores (1 - sample weight) should rank the contaminated
    /// sample above all clean samples.
    #[test]
    fn test_robust_outlier_detection() {
        let x = array![[1.0], [2.0], [3.0], [100.0]]; // Last point is outlier
        let y = array![1.0, 2.0, 3.0, 100.0]; // Last target is outlier

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = robust_krr.fit(&x, &y).expect("operation should succeed");
        let outlier_scores = fitted.outlier_scores().expect("operation should succeed");

        // Outlier should have higher score
        assert!(outlier_scores[3] > outlier_scores[0]);
        assert!(outlier_scores[3] > outlier_scores[1]);
        assert!(outlier_scores[3] > outlier_scores[2]);
    }

    /// IRLS should produce finite predictions even with a tight iteration
    /// budget and loose tolerance.
    #[test]
    fn test_robust_convergence() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .max_iter(5) // Small number of iterations
            .tolerance(1e-3);

        let fitted = robust_krr.fit(&x, &y).expect("operation should succeed");
        let predictions = fitted.predict(&x).expect("operation should succeed");

        assert_eq!(predictions.len(), 3);
        // Should converge even with few iterations for this simple case
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    /// The same seed must yield the same random feature map and therefore
    /// (near-)identical predictions across two independent fits.
    #[test]
    fn test_robust_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr1 = RobustKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42);
        let fitted1 = robust_krr1.fit(&x, &y).expect("operation should succeed");
        let pred1 = fitted1.predict(&x).expect("operation should succeed");

        let robust_krr2 = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42);
        let fitted2 = robust_krr2.fit(&x, &y).expect("operation should succeed");
        let pred2 = fitted2.predict(&x).expect("operation should succeed");

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }

    /// Sanity constraints at and away from zero for every built-in loss:
    /// non-negative loss and weight; weight roughly bounded near 1 at zero
    /// (the zero-residual guard returns exactly 1.0).
    #[test]
    fn test_robust_loss_edge_cases() {
        let losses = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
            RobustLoss::Logistic { scale: 1.0 },
            RobustLoss::Fair { c: 1.0 },
            RobustLoss::Welsch { c: 1.0 },
        ];

        for loss in losses {
            // Test with zero residual
            let loss_zero = loss.loss(0.0);
            let weight_zero = loss.weight(0.0);

            assert!(loss_zero >= 0.0);
            assert!(weight_zero >= 0.0);
            assert!(weight_zero <= 1.5); // Most weights should be <= 1.0, allowing some tolerance

            // Test with non-zero residual
            let loss_nonzero = loss.loss(1.0);
            let weight_nonzero = loss.weight(1.0);

            assert!(loss_nonzero >= 0.0);
            assert!(weight_nonzero >= 0.0);
        }
    }

    /// Custom fn-pointer loss: user-supplied loss/weight functions are
    /// dispatched as-is. Note weight(5.0) bypasses the zero-residual guard.
    #[test]
    fn test_custom_robust_loss() {
        let custom_loss = RobustLoss::Custom {
            loss_fn: |r| r * r, // Simple quadratic loss
            weight_fn: |_| 1.0, // Constant weight
        };

        assert_eq!(custom_loss.loss(2.0), 4.0);
        assert_eq!(custom_loss.weight(5.0), 1.0);
    }
}