sklears_kernel_approximation/kernel_ridge_regression/
robust_regression.rs

//! Robust Kernel Ridge Regression Implementation
//!
//! This module implements robust variants of kernel ridge regression that are resistant
//! to outliers and noise in the data. Multiple robust loss functions are supported,
//! and the optimization is performed using iteratively reweighted least squares (IRLS).
6
7use crate::{
8    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
9};
10use scirs2_linalg::compat::ArrayLinalgExt;
11// Removed SVD import - using ArrayLinalgExt for both solve and svd methods
12
13use scirs2_core::ndarray::{Array1, Array2};
14use sklears_core::error::{Result, SklearsError};
15use sklears_core::prelude::{Estimator, Fit, Float, Predict};
16use std::marker::PhantomData;
17
18use super::core_types::*;
19
20/// Robust kernel ridge regression
21///
22/// This implements robust variants of kernel ridge regression that are resistant
23/// to outliers and noise in the data. Multiple robust loss functions are supported.
24///
25/// # Parameters
26///
27/// * `approximation_method` - Method for kernel approximation
28/// * `alpha` - Regularization strength
29/// * `robust_loss` - Robust loss function to use
30/// * `solver` - Method for solving the optimization problem
31/// * `max_iter` - Maximum number of iterations for robust optimization
32/// * `tolerance` - Convergence tolerance
33/// * `random_state` - Random seed for reproducibility
34///
35/// # Examples
36///
37/// ```rust,ignore
38/// use sklears_kernel_approximation::kernel_ridge_regression::{
39#[derive(Debug, Clone)]
40pub struct RobustKernelRidgeRegression<State = Untrained> {
41    pub approximation_method: ApproximationMethod,
42    pub alpha: Float,
43    pub robust_loss: RobustLoss,
44    pub solver: Solver,
45    pub max_iter: usize,
46    pub tolerance: Float,
47    pub random_state: Option<u64>,
48
49    // Fitted parameters
50    weights_: Option<Array1<Float>>,
51    feature_transformer_: Option<FeatureTransformer>,
52    sample_weights_: Option<Array1<Float>>, // Adaptive weights for robustness
53
54    _state: PhantomData<State>,
55}
56
57/// Robust loss functions for kernel ridge regression
58#[derive(Debug, Clone)]
59pub enum RobustLoss {
60    /// Huber loss - quadratic for small residuals, linear for large ones
61    Huber { delta: Float },
62    /// Epsilon-insensitive loss (used in SVR)
63    EpsilonInsensitive { epsilon: Float },
64    /// Quantile loss for quantile regression
65    Quantile { tau: Float },
66    /// Tukey's biweight loss
67    Tukey { c: Float },
68    /// Cauchy loss
69    Cauchy { sigma: Float },
70    /// Logistic loss
71    Logistic { scale: Float },
72    /// Fair loss
73    Fair { c: Float },
74    /// Welsch loss
75    Welsch { c: Float },
76    /// Custom robust loss function
77    Custom {
78        loss_fn: fn(Float) -> Float,
79        weight_fn: fn(Float) -> Float,
80    },
81}
82
83impl Default for RobustLoss {
84    fn default() -> Self {
85        Self::Huber { delta: 1.0 }
86    }
87}
88
89impl RobustLoss {
90    /// Compute the loss for a given residual
91    pub fn loss(&self, residual: Float) -> Float {
92        let abs_r = residual.abs();
93        match self {
94            RobustLoss::Huber { delta } => {
95                if abs_r <= *delta {
96                    0.5 * residual * residual
97                } else {
98                    delta * (abs_r - 0.5 * delta)
99                }
100            }
101            RobustLoss::EpsilonInsensitive { epsilon } => (abs_r - epsilon).max(0.0),
102            RobustLoss::Quantile { tau } => {
103                if residual >= 0.0 {
104                    tau * residual
105                } else {
106                    (tau - 1.0) * residual
107                }
108            }
109            RobustLoss::Tukey { c } => {
110                if abs_r <= *c {
111                    let r_norm = residual / c;
112                    (c * c / 6.0) * (1.0 - (1.0 - r_norm * r_norm).powi(3))
113                } else {
114                    c * c / 6.0
115                }
116            }
117            RobustLoss::Cauchy { sigma } => {
118                (sigma * sigma / 2.0) * ((1.0 + (residual / sigma).powi(2)).ln())
119            }
120            RobustLoss::Logistic { scale } => scale * (1.0 + (-abs_r / scale).exp()).ln(),
121            RobustLoss::Fair { c } => c * (abs_r / c - (1.0 + abs_r / c).ln()),
122            RobustLoss::Welsch { c } => (c * c / 2.0) * (1.0 - (-((residual / c).powi(2))).exp()),
123            RobustLoss::Custom { loss_fn, .. } => loss_fn(residual),
124        }
125    }
126
127    /// Compute the weight for iteratively reweighted least squares
128    pub fn weight(&self, residual: Float) -> Float {
129        let abs_r = residual.abs();
130        if abs_r < 1e-10 {
131            return 1.0; // Avoid division by zero
132        }
133
134        match self {
135            RobustLoss::Huber { delta } => {
136                if abs_r <= *delta {
137                    1.0
138                } else {
139                    delta / abs_r
140                }
141            }
142            RobustLoss::EpsilonInsensitive { epsilon } => {
143                if abs_r <= *epsilon {
144                    0.0
145                } else {
146                    1.0
147                }
148            }
149            RobustLoss::Quantile { tau } => {
150                // For quantile regression, weights are constant
151                if residual >= 0.0 {
152                    *tau
153                } else {
154                    1.0 - tau
155                }
156            }
157            RobustLoss::Tukey { c } => {
158                if abs_r <= *c {
159                    let r_norm = residual / c;
160                    (1.0 - r_norm * r_norm).powi(2)
161                } else {
162                    0.0
163                }
164            }
165            RobustLoss::Cauchy { sigma } => 1.0 / (1.0 + (residual / sigma).powi(2)),
166            RobustLoss::Logistic { scale } => {
167                let exp_term = (-abs_r / scale).exp();
168                exp_term / (1.0 + exp_term)
169            }
170            RobustLoss::Fair { c } => 1.0 / (1.0 + abs_r / c),
171            RobustLoss::Welsch { c } => (-((residual / c).powi(2))).exp(),
172            RobustLoss::Custom { weight_fn, .. } => weight_fn(residual),
173        }
174    }
175}
176
177impl RobustKernelRidgeRegression<Untrained> {
178    /// Create a new robust kernel ridge regression model
179    pub fn new(approximation_method: ApproximationMethod) -> Self {
180        Self {
181            approximation_method,
182            alpha: 1.0,
183            robust_loss: RobustLoss::default(),
184            solver: Solver::Direct,
185            max_iter: 100,
186            tolerance: 1e-6,
187            random_state: None,
188            weights_: None,
189            feature_transformer_: None,
190            sample_weights_: None,
191            _state: PhantomData,
192        }
193    }
194
195    /// Set regularization parameter
196    pub fn alpha(mut self, alpha: Float) -> Self {
197        self.alpha = alpha;
198        self
199    }
200
201    /// Set robust loss function
202    pub fn robust_loss(mut self, robust_loss: RobustLoss) -> Self {
203        self.robust_loss = robust_loss;
204        self
205    }
206
207    /// Set solver method
208    pub fn solver(mut self, solver: Solver) -> Self {
209        self.solver = solver;
210        self
211    }
212
213    /// Set maximum iterations for robust optimization
214    pub fn max_iter(mut self, max_iter: usize) -> Self {
215        self.max_iter = max_iter;
216        self
217    }
218
219    /// Set convergence tolerance
220    pub fn tolerance(mut self, tolerance: Float) -> Self {
221        self.tolerance = tolerance;
222        self
223    }
224
225    /// Set random state for reproducibility
226    pub fn random_state(mut self, seed: u64) -> Self {
227        self.random_state = Some(seed);
228        self
229    }
230}
231
232impl Estimator for RobustKernelRidgeRegression<Untrained> {
233    type Config = ();
234    type Error = SklearsError;
235    type Float = Float;
236
237    fn config(&self) -> &Self::Config {
238        &()
239    }
240}
241
242impl Fit<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Untrained> {
243    type Fitted = RobustKernelRidgeRegression<Trained>;
244
245    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
246        if x.nrows() != y.len() {
247            return Err(SklearsError::InvalidInput(
248                "Number of samples must match".to_string(),
249            ));
250        }
251
252        // Fit the feature transformer
253        let feature_transformer = self.fit_feature_transformer(x)?;
254        let x_transformed = feature_transformer.transform(x)?;
255
256        // Solve robust regression using iteratively reweighted least squares (IRLS)
257        let (weights, sample_weights) = self.solve_robust_regression(&x_transformed, y)?;
258
259        Ok(RobustKernelRidgeRegression {
260            approximation_method: self.approximation_method,
261            alpha: self.alpha,
262            robust_loss: self.robust_loss,
263            solver: self.solver,
264            max_iter: self.max_iter,
265            tolerance: self.tolerance,
266            random_state: self.random_state,
267            weights_: Some(weights),
268            feature_transformer_: Some(feature_transformer),
269            sample_weights_: Some(sample_weights),
270            _state: PhantomData,
271        })
272    }
273}
274
275impl RobustKernelRidgeRegression<Untrained> {
276    /// Fit the feature transformer based on the approximation method
277    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
278        match &self.approximation_method {
279            ApproximationMethod::Nystroem {
280                kernel,
281                n_components,
282                sampling_strategy,
283            } => {
284                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
285                    .sampling_strategy(sampling_strategy.clone());
286                if let Some(seed) = self.random_state {
287                    nystroem = nystroem.random_state(seed);
288                }
289                let fitted = nystroem.fit(x, &())?;
290                Ok(FeatureTransformer::Nystroem(fitted))
291            }
292            ApproximationMethod::RandomFourierFeatures {
293                n_components,
294                gamma,
295            } => {
296                let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
297                if let Some(seed) = self.random_state {
298                    rff = rff.random_state(seed);
299                }
300                let fitted = rff.fit(x, &())?;
301                Ok(FeatureTransformer::RBFSampler(fitted))
302            }
303            ApproximationMethod::StructuredRandomFeatures {
304                n_components,
305                gamma,
306            } => {
307                let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
308                if let Some(seed) = self.random_state {
309                    srf = srf.random_state(seed);
310                }
311                let fitted = srf.fit(x, &())?;
312                Ok(FeatureTransformer::StructuredRFF(fitted))
313            }
314            ApproximationMethod::Fastfood {
315                n_components,
316                gamma,
317            } => {
318                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
319                if let Some(seed) = self.random_state {
320                    fastfood = fastfood.random_state(seed);
321                }
322                let fitted = fastfood.fit(x, &())?;
323                Ok(FeatureTransformer::Fastfood(fitted))
324            }
325        }
326    }
327
328    /// Solve robust regression using iteratively reweighted least squares
329    fn solve_robust_regression(
330        &self,
331        x: &Array2<Float>,
332        y: &Array1<Float>,
333    ) -> Result<(Array1<Float>, Array1<Float>)> {
334        let n_samples = x.nrows();
335        let n_features = x.ncols();
336
337        // Initialize with ordinary least squares solution
338        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
339        let y_f64 = Array1::from_vec(y.iter().copied().collect());
340
341        let xtx = x_f64.t().dot(&x_f64);
342        let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;
343        let xty = x_f64.t().dot(&y_f64);
344        let mut weights_f64 =
345            regularized_xtx
346                .solve(&xty)
347                .map_err(|e| SklearsError::InvalidParameter {
348                    name: "regularization".to_string(),
349                    reason: format!("Initial linear system solving failed: {:?}", e),
350                })?;
351
352        let mut sample_weights = Array1::ones(n_samples);
353        let mut prev_weights = weights_f64.clone();
354
355        // Iteratively reweighted least squares
356        for _iter in 0..self.max_iter {
357            // Compute residuals
358            let predictions = x_f64.dot(&weights_f64);
359            let residuals = &y_f64 - &predictions;
360
361            // Update sample weights based on residuals
362            for (i, &residual) in residuals.iter().enumerate() {
363                sample_weights[i] = self.robust_loss.weight(residual as Float);
364            }
365
366            // Solve weighted least squares
367            let mut weighted_xtx = Array2::zeros((n_features, n_features));
368            let mut weighted_xty = Array1::zeros(n_features);
369
370            for i in 0..n_samples {
371                let weight = sample_weights[i];
372                let x_row = x_f64.row(i);
373
374                // X^T W X
375                for j in 0..n_features {
376                    for k in 0..n_features {
377                        weighted_xtx[[j, k]] += weight * x_row[j] * x_row[k];
378                    }
379                }
380
381                // X^T W y
382                for j in 0..n_features {
383                    weighted_xty[j] += weight * x_row[j] * y_f64[i];
384                }
385            }
386
387            // Add regularization
388            weighted_xtx += &(Array2::eye(n_features) * self.alpha);
389
390            // Solve the weighted system
391            weights_f64 = match self.solver {
392                Solver::Direct => weighted_xtx.solve(&weighted_xty).map_err(|e| {
393                    SklearsError::InvalidParameter {
394                        name: "weighted_system".to_string(),
395                        reason: format!("Weighted linear system solving failed: {:?}", e),
396                    }
397                })?,
398                Solver::SVD => {
399                    let (u, s, vt) =
400                        weighted_xtx
401                            .svd(true)
402                            .map_err(|e| SklearsError::InvalidParameter {
403                                name: "svd".to_string(),
404                                reason: format!("SVD decomposition failed: {:?}", e),
405                            })?;
406                    let ut_b = u.t().dot(&weighted_xty);
407                    let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
408                    let y_svd = ut_b * s_inv;
409                    vt.t().dot(&y_svd)
410                }
411                Solver::ConjugateGradient { max_iter, tol } => {
412                    self.solve_cg_weighted(&weighted_xtx, &weighted_xty, max_iter, tol)?
413                }
414            };
415
416            // Check convergence
417            let weight_change = (&weights_f64 - &prev_weights).mapv(|x| x.abs()).sum();
418            if weight_change < self.tolerance {
419                break;
420            }
421
422            prev_weights = weights_f64.clone();
423        }
424
425        // Convert back to Float
426        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
427        let sample_weights_float =
428            Array1::from_vec(sample_weights.iter().map(|&val| val as Float).collect());
429
430        Ok((weights, sample_weights_float))
431    }
432
433    /// Solve using conjugate gradient method
434    fn solve_cg_weighted(
435        &self,
436        a: &Array2<f64>,
437        b: &Array1<f64>,
438        max_iter: usize,
439        tol: f64,
440    ) -> Result<Array1<f64>> {
441        let n = b.len();
442        let mut x = Array1::zeros(n);
443        let mut r = b - &a.dot(&x);
444        let mut p = r.clone();
445        let mut rsold = r.dot(&r);
446
447        for _iter in 0..max_iter {
448            let ap = a.dot(&p);
449            let alpha = rsold / p.dot(&ap);
450
451            x = x + &p * alpha;
452            r = r - &ap * alpha;
453
454            let rsnew = r.dot(&r);
455
456            if rsnew.sqrt() < tol {
457                break;
458            }
459
460            let beta = rsnew / rsold;
461            p = &r + &p * beta;
462            rsold = rsnew;
463        }
464
465        Ok(x)
466    }
467}
468
469impl Predict<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Trained> {
470    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
471        let feature_transformer =
472            self.feature_transformer_
473                .as_ref()
474                .ok_or_else(|| SklearsError::NotFitted {
475                    operation: "predict".to_string(),
476                })?;
477
478        let weights = self
479            .weights_
480            .as_ref()
481            .ok_or_else(|| SklearsError::NotFitted {
482                operation: "predict".to_string(),
483            })?;
484
485        let x_transformed = feature_transformer.transform(x)?;
486        let predictions = x_transformed.dot(weights);
487
488        Ok(predictions)
489    }
490}
491
492impl RobustKernelRidgeRegression<Trained> {
493    /// Get the fitted weights
494    pub fn weights(&self) -> Option<&Array1<Float>> {
495        self.weights_.as_ref()
496    }
497
498    /// Get the sample weights from robust fitting
499    pub fn sample_weights(&self) -> Option<&Array1<Float>> {
500        self.sample_weights_.as_ref()
501    }
502
503    /// Compute robust residuals and their weights
504    pub fn robust_residuals(
505        &self,
506        x: &Array2<Float>,
507        y: &Array1<Float>,
508    ) -> Result<(Array1<Float>, Array1<Float>)> {
509        let predictions = self.predict(x)?;
510        let residuals = y - &predictions;
511
512        let mut weights = Array1::zeros(residuals.len());
513        for (i, &residual) in residuals.iter().enumerate() {
514            weights[i] = self.robust_loss.weight(residual);
515        }
516
517        Ok((residuals, weights))
518    }
519
520    /// Get outlier scores (lower weight means more likely to be outlier)
521    pub fn outlier_scores(&self) -> Option<Array1<Float>> {
522        self.sample_weights_.as_ref().map(|weights| {
523            // Convert weights to outlier scores (1 - weight)
524            weights.mapv(|w| 1.0 - w)
525        })
526    }
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// End-to-end fit/predict with a single gross outlier in the targets.
    #[test]
    fn test_robust_kernel_ridge_regression() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [10.0, 10.0]]; // Last point is outlier
        let y = array![1.0, 2.0, 3.0, 100.0]; // Last target is outlier

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = robust_krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);

        // Check that predictions are reasonable
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }

        // Check that outlier has lower weight
        let sample_weights = fitted.sample_weights().unwrap();
        assert!(sample_weights[3] < sample_weights[0]); // Outlier should have lower weight
    }

    /// Every listed loss must fit and predict without error.
    #[test]
    fn test_different_robust_losses() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let loss_functions = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
        ];

        for loss in loss_functions {
            let robust_krr = RobustKernelRidgeRegression::new(approximation.clone())
                .alpha(0.1)
                .robust_loss(loss);

            let fitted = robust_krr.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();

            assert_eq!(predictions.len(), 3);
        }
    }

    /// Exact Huber values in both the quadratic and the linear region.
    #[test]
    fn test_robust_loss_functions() {
        let huber = RobustLoss::Huber { delta: 1.0 };

        // Test loss computation
        assert_eq!(huber.loss(0.5), 0.125); // Quadratic region
        assert_eq!(huber.loss(2.0), 1.5); // Linear region

        // Test weight computation
        assert_eq!(huber.weight(0.5), 1.0); // Quadratic region
        assert_eq!(huber.weight(2.0), 0.5); // Linear region
    }

    /// The outlier must receive the highest outlier score.
    #[test]
    fn test_robust_outlier_detection() {
        let x = array![[1.0], [2.0], [3.0], [100.0]]; // Last point is outlier
        let y = array![1.0, 2.0, 3.0, 100.0]; // Last target is outlier

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = robust_krr.fit(&x, &y).unwrap();
        let outlier_scores = fitted.outlier_scores().unwrap();

        // Outlier should have higher score
        assert!(outlier_scores[3] > outlier_scores[0]);
        assert!(outlier_scores[3] > outlier_scores[1]);
        assert!(outlier_scores[3] > outlier_scores[2]);
    }

    /// A tight iteration budget must still produce finite predictions.
    #[test]
    fn test_robust_convergence() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .max_iter(5) // Small number of iterations
            .tolerance(1e-3);

        let fitted = robust_krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
        // Should converge even with few iterations for this simple case
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    /// Two models built with the same seed must predict the same values.
    #[test]
    fn test_robust_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr1 = RobustKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42);
        let fitted1 = robust_krr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let robust_krr2 = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42);
        let fitted2 = robust_krr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }

    /// Losses and weights stay non-negative at zero and at a unit residual.
    #[test]
    fn test_robust_loss_edge_cases() {
        let losses = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
            RobustLoss::Logistic { scale: 1.0 },
            RobustLoss::Fair { c: 1.0 },
            RobustLoss::Welsch { c: 1.0 },
        ];

        for loss in losses {
            // Test with zero residual
            let loss_zero = loss.loss(0.0);
            let weight_zero = loss.weight(0.0);

            assert!(loss_zero >= 0.0);
            assert!(weight_zero >= 0.0);
            assert!(weight_zero <= 1.5); // Most weights should be <= 1.0, allowing some tolerance

            // Test with non-zero residual
            let loss_nonzero = loss.loss(1.0);
            let weight_nonzero = loss.weight(1.0);

            assert!(loss_nonzero >= 0.0);
            assert!(weight_nonzero >= 0.0);
        }
    }

    /// User-supplied loss and weight functions are called for non-zero residuals.
    #[test]
    fn test_custom_robust_loss() {
        let custom_loss = RobustLoss::Custom {
            loss_fn: |r| r * r, // Simple quadratic loss
            weight_fn: |_| 1.0, // Constant weight
        };

        assert_eq!(custom_loss.loss(2.0), 4.0);
        assert_eq!(custom_loss.weight(5.0), 1.0);
    }
}