use crate::{
    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
};
use scirs2_linalg::compat::ArrayLinalgExt;
use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::{Estimator, Fit, Float, Predict};
use std::marker::PhantomData;

use super::core_types::*;

/// Robust kernel ridge regression on an approximate kernel feature map.
///
/// Inputs are expanded with the configured approximation method and the ridge
/// weights are then fitted by iteratively reweighted least squares (IRLS),
/// with the robust loss down-weighting samples that have large residuals.
#[derive(Debug, Clone)]
pub struct RobustKernelRidgeRegression<State = Untrained> {
    /// Kernel approximation used to build the explicit feature map.
    pub approximation_method: ApproximationMethod,
    /// Ridge regularization strength.
    pub alpha: Float,
    /// Robust loss controlling the IRLS sample weights.
    pub robust_loss: RobustLoss,
    /// Solver for each weighted least-squares subproblem.
    pub solver: Solver,
    /// Maximum number of IRLS iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the change in weights.
    pub tolerance: Float,
    /// Seed for the random feature map.
    pub random_state: Option<u64>,

    weights_: Option<Array1<Float>>,
    feature_transformer_: Option<FeatureTransformer>,
    sample_weights_: Option<Array1<Float>>,
    _state: PhantomData<State>,
}

/// Robust loss functions used to down-weight outlier residuals during IRLS.
#[derive(Debug, Clone)]
pub enum RobustLoss {
    /// Huber loss: quadratic for |r| <= delta, linear beyond.
    Huber { delta: Float },
    /// Epsilon-insensitive loss: zero inside the epsilon tube, linear outside.
    EpsilonInsensitive { epsilon: Float },
    /// Quantile (pinball) loss for quantile level `tau`.
    Quantile { tau: Float },
    /// Tukey's biweight loss; residuals beyond `c` receive zero weight.
    Tukey { c: Float },
    /// Cauchy (Lorentzian) loss with scale `sigma`.
    Cauchy { sigma: Float },
    /// Logistic loss with the given scale.
    Logistic { scale: Float },
    /// Fair loss with tuning constant `c`.
    Fair { c: Float },
    /// Welsch (exponential) loss with tuning constant `c`.
    Welsch { c: Float },
    /// User-supplied loss and IRLS weight functions.
    Custom {
        loss_fn: fn(Float) -> Float,
        weight_fn: fn(Float) -> Float,
    },
}

impl Default for RobustLoss {
    fn default() -> Self {
        Self::Huber { delta: 1.0 }
    }
}

impl RobustLoss {
    /// Evaluate the robust loss at the given residual.
    pub fn loss(&self, residual: Float) -> Float {
        let abs_r = residual.abs();
        match self {
            RobustLoss::Huber { delta } => {
                if abs_r <= *delta {
                    0.5 * residual * residual
                } else {
                    delta * (abs_r - 0.5 * delta)
                }
            }
            RobustLoss::EpsilonInsensitive { epsilon } => (abs_r - epsilon).max(0.0),
            RobustLoss::Quantile { tau } => {
                if residual >= 0.0 {
                    tau * residual
                } else {
                    (tau - 1.0) * residual
                }
            }
            RobustLoss::Tukey { c } => {
                if abs_r <= *c {
                    let r_norm = residual / c;
                    (c * c / 6.0) * (1.0 - (1.0 - r_norm * r_norm).powi(3))
                } else {
                    c * c / 6.0
                }
            }
            RobustLoss::Cauchy { sigma } => {
                (sigma * sigma / 2.0) * ((1.0 + (residual / sigma).powi(2)).ln())
            }
            // Standard logistic M-estimator loss: scale^2 * ln(cosh(r / scale)).
            RobustLoss::Logistic { scale } => scale * scale * (residual / scale).cosh().ln(),
            // Fair loss: c^2 * (|r|/c - ln(1 + |r|/c)), consistent with its IRLS weight.
            RobustLoss::Fair { c } => c * c * (abs_r / c - (1.0 + abs_r / c).ln()),
            RobustLoss::Welsch { c } => (c * c / 2.0) * (1.0 - (-((residual / c).powi(2))).exp()),
            RobustLoss::Custom { loss_fn, .. } => loss_fn(residual),
        }
    }

    /// Per-sample IRLS weight for the given residual; residuals that are
    /// numerically zero receive full weight 1.0.
    pub fn weight(&self, residual: Float) -> Float {
        let abs_r = residual.abs();
        if abs_r < 1e-10 {
            return 1.0;
        }

        match self {
            RobustLoss::Huber { delta } => {
                if abs_r <= *delta {
                    1.0
                } else {
                    delta / abs_r
                }
            }
            RobustLoss::EpsilonInsensitive { epsilon } => {
                if abs_r <= *epsilon {
                    0.0
                } else {
                    1.0
                }
            }
            RobustLoss::Quantile { tau } => {
                if residual >= 0.0 {
                    *tau
                } else {
                    1.0 - tau
                }
            }
            RobustLoss::Tukey { c } => {
                if abs_r <= *c {
                    let r_norm = residual / c;
                    (1.0 - r_norm * r_norm).powi(2)
                } else {
                    0.0
                }
            }
            RobustLoss::Cauchy { sigma } => 1.0 / (1.0 + (residual / sigma).powi(2)),
            // IRLS weight matching the logistic loss: tanh(|r| / scale) / (|r| / scale).
            RobustLoss::Logistic { scale } => {
                let z = abs_r / scale;
                z.tanh() / z
            }
            RobustLoss::Fair { c } => 1.0 / (1.0 + abs_r / c),
            RobustLoss::Welsch { c } => (-((residual / c).powi(2))).exp(),
            RobustLoss::Custom { weight_fn, .. } => weight_fn(residual),
        }
    }
}

impl RobustKernelRidgeRegression<Untrained> {
    /// Create an untrained model using the given kernel approximation and
    /// default hyperparameters.
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            robust_loss: RobustLoss::default(),
            solver: Solver::Direct,
            max_iter: 100,
            tolerance: 1e-6,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            sample_weights_: None,
            _state: PhantomData,
        }
    }

    /// Set the ridge regularization strength (default: 1.0).
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the robust loss used for IRLS reweighting (default: Huber with delta = 1.0).
    pub fn robust_loss(mut self, robust_loss: RobustLoss) -> Self {
        self.robust_loss = robust_loss;
        self
    }

    /// Set the solver for the weighted least-squares subproblems (default: direct solve).
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

    /// Set the maximum number of IRLS iterations (default: 100).
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance on the change in weights (default: 1e-6).
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set the random seed used by the feature map, for reproducible fits.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for RobustKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Untrained> {
    type Fitted = RobustKernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        if x.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(format!(
                "x has {} samples but y has {} targets",
                x.nrows(),
                y.len()
            )));
        }

        // Fit the kernel approximation, expand the inputs, then run robust IRLS.
        let feature_transformer = self.fit_feature_transformer(x)?;
        let x_transformed = feature_transformer.transform(x)?;

        let (weights, sample_weights) = self.solve_robust_regression(&x_transformed, y)?;

        Ok(RobustKernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            robust_loss: self.robust_loss,
            solver: self.solver,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            sample_weights_: Some(sample_weights),
            _state: PhantomData,
        })
    }
}

impl RobustKernelRidgeRegression<Untrained> {
    /// Fit the configured kernel approximation on the training data and wrap it
    /// in a `FeatureTransformer`.
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());
                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }
                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    rff = rff.random_state(seed);
                }
                let fitted = rff.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    srf = srf.random_state(seed);
                }
                let fitted = srf.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }
                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Solve the robust regression by iteratively reweighted least squares (IRLS)
    /// in the approximate feature space, returning the fitted weights and the
    /// final per-sample weights.
    fn solve_robust_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(Array1<Float>, Array1<Float>)> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
        let y_f64 = Array1::from_vec(y.iter().copied().collect());

        // Initial estimate: ordinary ridge regression.
        let xtx = x_f64.t().dot(&x_f64);
        let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;
        let xty = x_f64.t().dot(&y_f64);
        let mut weights_f64 =
            regularized_xtx
                .solve(&xty)
                .map_err(|e| SklearsError::InvalidParameter {
                    name: "regularization".to_string(),
                    reason: format!("Initial linear system solving failed: {:?}", e),
                })?;

        let mut sample_weights = Array1::ones(n_samples);
        let mut prev_weights = weights_f64.clone();

        for _iter in 0..self.max_iter {
            // Update the per-sample robust weights from the current residuals.
            let predictions = x_f64.dot(&weights_f64);
            let residuals = &y_f64 - &predictions;

            for (i, &residual) in residuals.iter().enumerate() {
                sample_weights[i] = self.robust_loss.weight(residual as Float);
            }

            // Form the weighted normal equations X^T W X and X^T W y.
            let mut weighted_xtx = Array2::zeros((n_features, n_features));
            let mut weighted_xty = Array1::zeros(n_features);

            for i in 0..n_samples {
                let weight = sample_weights[i];
                let x_row = x_f64.row(i);

                for j in 0..n_features {
                    for k in 0..n_features {
                        weighted_xtx[[j, k]] += weight * x_row[j] * x_row[k];
                    }
                }

                for j in 0..n_features {
                    weighted_xty[j] += weight * x_row[j] * y_f64[i];
                }
            }

            // Ridge regularization.
            weighted_xtx += &(Array2::eye(n_features) * self.alpha);

            weights_f64 = match self.solver {
                Solver::Direct => weighted_xtx.solve(&weighted_xty).map_err(|e| {
                    SklearsError::InvalidParameter {
                        name: "weighted_system".to_string(),
                        reason: format!("Weighted linear system solving failed: {:?}", e),
                    }
                })?,
                Solver::SVD => {
                    let (u, s, vt) =
                        weighted_xtx
                            .svd(true)
                            .map_err(|e| SklearsError::InvalidParameter {
                                name: "svd".to_string(),
                                reason: format!("SVD decomposition failed: {:?}", e),
                            })?;
                    // Pseudo-inverse via SVD: drop directions with near-zero singular values.
                    let ut_b = u.t().dot(&weighted_xty);
                    let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
                    let y_svd = ut_b * s_inv;
                    vt.t().dot(&y_svd)
                }
                Solver::ConjugateGradient { max_iter, tol } => {
                    self.solve_cg_weighted(&weighted_xtx, &weighted_xty, max_iter, tol)?
                }
            };

            // Stop once the weights change by less than the tolerance (L1 norm).
            let weight_change = (&weights_f64 - &prev_weights).mapv(|x| x.abs()).sum();
            if weight_change < self.tolerance {
                break;
            }

            prev_weights = weights_f64.clone();
        }

        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
        let sample_weights_float =
            Array1::from_vec(sample_weights.iter().map(|&val| val as Float).collect());

        Ok((weights, sample_weights_float))
    }

    /// Solve `a * x = b` with a plain conjugate gradient iteration; `a` is the
    /// regularized, symmetric positive definite weighted Gram matrix.
    fn solve_cg_weighted(
        &self,
        a: &Array2<f64>,
        b: &Array1<f64>,
        max_iter: usize,
        tol: f64,
    ) -> Result<Array1<f64>> {
        let n = b.len();
        let mut x = Array1::zeros(n);
        let mut r = b - &a.dot(&x);
        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for _iter in 0..max_iter {
            let ap = a.dot(&p);
            let alpha = rsold / p.dot(&ap);

            x = x + &p * alpha;
            r = r - &ap * alpha;

            let rsnew = r.dot(&r);

            // Converged once the residual norm drops below the tolerance.
            if rsnew.sqrt() < tol {
                break;
            }

            let beta = rsnew / rsold;
            p = &r + &p * beta;
            rsold = rsnew;
        }

        Ok(x)
    }
}

impl Predict<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let feature_transformer =
            self.feature_transformer_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "predict".to_string(),
                })?;

        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let x_transformed = feature_transformer.transform(x)?;
        let predictions = x_transformed.dot(weights);

        Ok(predictions)
    }
}

impl RobustKernelRidgeRegression<Trained> {
    /// Fitted weights in the approximate feature space.
    pub fn weights(&self) -> Option<&Array1<Float>> {
        self.weights_.as_ref()
    }

    /// Final IRLS sample weights from training; smaller values mark samples
    /// that were down-weighted as outliers.
    pub fn sample_weights(&self) -> Option<&Array1<Float>> {
        self.sample_weights_.as_ref()
    }

    /// Residuals on the given data together with their robust IRLS weights.
    pub fn robust_residuals(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(Array1<Float>, Array1<Float>)> {
        let predictions = self.predict(x)?;
        let residuals = y - &predictions;

        let mut weights = Array1::zeros(residuals.len());
        for (i, &residual) in residuals.iter().enumerate() {
            weights[i] = self.robust_loss.weight(residual);
        }

        Ok((residuals, weights))
    }

    /// Outlier scores for the training samples (one minus the sample weight);
    /// higher values indicate more outlying samples.
    pub fn outlier_scores(&self) -> Option<Array1<Float>> {
        self.sample_weights_
            .as_ref()
            .map(|weights| weights.mapv(|w| 1.0 - w))
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_robust_kernel_ridge_regression() {
        // The last sample is a gross outlier in both x and y.
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [10.0, 10.0]];
        let y = array![1.0, 2.0, 3.0, 100.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = robust_krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);

        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }

        // The outlier should receive a smaller IRLS weight than an inlier.
        let sample_weights = fitted.sample_weights().unwrap();
        assert!(sample_weights[3] < sample_weights[0]);
    }

    #[test]
    fn test_different_robust_losses() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let loss_functions = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
        ];

        for loss in loss_functions {
            let robust_krr = RobustKernelRidgeRegression::new(approximation.clone())
                .alpha(0.1)
                .robust_loss(loss);

            let fitted = robust_krr.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();

            assert_eq!(predictions.len(), 3);
        }
    }

    #[test]
    fn test_robust_loss_functions() {
        let huber = RobustLoss::Huber { delta: 1.0 };

        // Quadratic region: 0.5 * 0.5^2 = 0.125; linear region: 1.0 * (2.0 - 0.5) = 1.5.
        assert_eq!(huber.loss(0.5), 0.125);
        assert_eq!(huber.loss(2.0), 1.5);
        // Full weight inside delta, delta / |r| = 0.5 outside.
        assert_eq!(huber.weight(0.5), 1.0);
        assert_eq!(huber.weight(2.0), 0.5);
    }

    #[test]
    fn test_robust_outlier_detection() {
        // The last sample lies far from the others.
        let x = array![[1.0], [2.0], [3.0], [100.0]];
        let y = array![1.0, 2.0, 3.0, 100.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = robust_krr.fit(&x, &y).unwrap();
        let outlier_scores = fitted.outlier_scores().unwrap();

        assert!(outlier_scores[3] > outlier_scores[0]);
        assert!(outlier_scores[3] > outlier_scores[1]);
        assert!(outlier_scores[3] > outlier_scores[2]);
    }
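
    #[test]
    fn test_robust_residuals_shapes() {
        // Supplementary sketch (not part of the original suite): exercises
        // `robust_residuals` on a tiny problem and only checks shapes and the
        // Huber weight bounds, since exact values depend on the random feature map.
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let fitted = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(0)
            .fit(&x, &y)
            .unwrap();

        let (residuals, weights) = fitted.robust_residuals(&x, &y).unwrap();
        assert_eq!(residuals.len(), 3);
        assert_eq!(weights.len(), 3);
        for &w in weights.iter() {
            // Huber IRLS weights lie in (0, 1].
            assert!(w > 0.0 && w <= 1.0);
        }
    }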

    #[test]
    fn test_robust_convergence() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Few iterations and a loose tolerance should still give finite predictions.
        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .max_iter(5)
            .tolerance(1e-3);

        let fitted = robust_krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }
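
    #[test]
    fn test_robust_alternative_solvers() {
        // Supplementary sketch: routes the same small problem through the
        // `Solver::SVD` and `Solver::ConjugateGradient` variants handled in
        // `solve_robust_regression`; only length and finiteness are asserted.
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let solvers = vec![
            Solver::SVD,
            Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-8,
            },
        ];

        for solver in solvers {
            let fitted = RobustKernelRidgeRegression::new(approximation.clone())
                .alpha(0.1)
                .solver(solver)
                .fit(&x, &y)
                .unwrap();
            let predictions = fitted.predict(&x).unwrap();

            assert_eq!(predictions.len(), 3);
            for pred in predictions.iter() {
                assert!(pred.is_finite());
            }
        }
    }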

    #[test]
    fn test_robust_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr1 = RobustKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42);
        let fitted1 = robust_krr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let robust_krr2 = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42);
        let fitted2 = robust_krr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }

    #[test]
    fn test_robust_loss_edge_cases() {
        let losses = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
            RobustLoss::Logistic { scale: 1.0 },
            RobustLoss::Fair { c: 1.0 },
            RobustLoss::Welsch { c: 1.0 },
        ];

        for loss in losses {
            // At zero residual: non-negative loss and a bounded weight.
            let loss_zero = loss.loss(0.0);
            let weight_zero = loss.weight(0.0);

            assert!(loss_zero >= 0.0);
            assert!(weight_zero >= 0.0);
            assert!(weight_zero <= 1.5);

            // A nonzero residual should also give a non-negative loss and weight.
            let loss_nonzero = loss.loss(1.0);
            let weight_nonzero = loss.weight(1.0);

            assert!(loss_nonzero >= 0.0);
            assert!(weight_nonzero >= 0.0);
        }
    }
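
    #[test]
    fn test_robust_weights_nonincreasing() {
        // Supplementary sketch: IRLS weights should not grow as the residual
        // moves further from zero, so each loss is checked at a small and a
        // large residual.
        let losses = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
            RobustLoss::Fair { c: 1.0 },
            RobustLoss::Welsch { c: 1.0 },
        ];

        for loss in losses {
            assert!(loss.weight(0.5) >= loss.weight(5.0));
        }
    }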

    #[test]
    fn test_custom_robust_loss() {
        let custom_loss = RobustLoss::Custom {
            loss_fn: |r| r * r,
            weight_fn: |_| 1.0,
        };

        assert_eq!(custom_loss.loss(2.0), 4.0);
        assert_eq!(custom_loss.weight(5.0), 1.0);
    }
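
    #[test]
    fn test_fit_with_custom_robust_loss() {
        // Supplementary sketch: fits with a user-supplied `Custom` loss/weight
        // pair (simple absolute-residual down-weighting) and checks only the
        // output shape.
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Custom {
                loss_fn: |r| 0.5 * r * r,
                weight_fn: |r| 1.0 / (1.0 + r.abs()),
            });

        let fitted = robust_krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
    }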
}