1use crate::{
8 FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
9};
10use scirs2_core::ndarray::ndarray_linalg::solve::Solve;
11use scirs2_core::ndarray::ndarray_linalg::SVD;
12use scirs2_core::ndarray::{Array1, Array2};
13use sklears_core::error::{Result, SklearsError};
14use sklears_core::prelude::{Estimator, Fit, Float, Predict};
15use std::marker::PhantomData;
16
17use super::core_types::*;
18
/// Kernel ridge regression with robust (outlier-resistant) loss functions.
///
/// The kernel is approximated by an explicit feature map (Nystroem, random
/// Fourier features, structured random features, or Fastfood) and the
/// resulting ridge problem is solved by iteratively reweighted least squares
/// (IRLS), with per-sample weights taken from `robust_loss`.
///
/// `State` is a typestate marker: `Untrained` until `fit` succeeds, then
/// `Trained` (which unlocks `predict` and the trained-only accessors).
#[derive(Debug, Clone)]
pub struct RobustKernelRidgeRegression<State = Untrained> {
    /// Kernel approximation used to build the explicit feature map.
    pub approximation_method: ApproximationMethod,
    /// L2 regularization strength added to the normal equations.
    pub alpha: Float,
    /// Robust loss that determines per-sample IRLS weights.
    pub robust_loss: RobustLoss,
    /// Linear-system solver used in each IRLS iteration.
    pub solver: Solver,
    /// Maximum number of IRLS iterations.
    pub max_iter: usize,
    /// Convergence threshold on the L1 change of the coefficient vector.
    pub tolerance: Float,
    /// Seed forwarded to the random feature transformer for reproducibility.
    pub random_state: Option<u64>,

    // Learned coefficients in the transformed feature space (set by fit).
    weights_: Option<Array1<Float>>,
    // Fitted feature-map transformer (set by fit).
    feature_transformer_: Option<FeatureTransformer>,
    // Final IRLS per-sample weights; small values mark likely outliers.
    sample_weights_: Option<Array1<Float>>,
    // Zero-sized typestate marker.
    _state: PhantomData<State>,
}
55
/// Robust loss functions for outlier-resistant regression.
///
/// Each variant defines both a loss value and an IRLS weight; see
/// [`RobustLoss::loss`] and [`RobustLoss::weight`].
#[derive(Debug, Clone)]
pub enum RobustLoss {
    /// Quadratic near zero, linear beyond `delta` (Huber).
    Huber { delta: Float },
    /// Zero inside the `epsilon` tube, linear outside (SVR-style).
    EpsilonInsensitive { epsilon: Float },
    /// Pinball loss for quantile level `tau` (asymmetric).
    Quantile { tau: Float },
    /// Tukey's biweight: redescending, saturates for residuals beyond `c`.
    Tukey { c: Float },
    /// Cauchy (Lorentzian) loss with scale `sigma`.
    Cauchy { sigma: Float },
    /// Smooth logistic-type loss with the given `scale`.
    Logistic { scale: Float },
    /// Fair loss with tuning constant `c`.
    Fair { c: Float },
    /// Welsch loss: exponentially downweights large residuals.
    Welsch { c: Float },
    /// User-supplied loss and IRLS-weight functions of the residual.
    Custom {
        loss_fn: fn(Float) -> Float,
        weight_fn: fn(Float) -> Float,
    },
}
81
82impl Default for RobustLoss {
83 fn default() -> Self {
84 Self::Huber { delta: 1.0 }
85 }
86}
87
impl RobustLoss {
    /// Loss value for a single residual.
    ///
    /// All variants are non-negative. `Quantile` is intentionally asymmetric
    /// (pinball loss); the others depend only on `|residual|`.
    /// NOTE(review): `Logistic` evaluates to `scale * ln 2` at residual 0,
    /// i.e. it is not anchored at zero — confirm intended.
    pub fn loss(&self, residual: Float) -> Float {
        let abs_r = residual.abs();
        match self {
            RobustLoss::Huber { delta } => {
                if abs_r <= *delta {
                    // Quadratic region near zero.
                    0.5 * residual * residual
                } else {
                    // Linear tail; continuous with the quadratic at |r| = delta.
                    delta * (abs_r - 0.5 * delta)
                }
            }
            // Zero inside the epsilon tube, linear outside it.
            RobustLoss::EpsilonInsensitive { epsilon } => (abs_r - epsilon).max(0.0),
            RobustLoss::Quantile { tau } => {
                if residual >= 0.0 {
                    tau * residual
                } else {
                    // (tau - 1) < 0 and residual < 0, so the loss stays positive.
                    (tau - 1.0) * residual
                }
            }
            RobustLoss::Tukey { c } => {
                if abs_r <= *c {
                    let r_norm = residual / c;
                    (c * c / 6.0) * (1.0 - (1.0 - r_norm * r_norm).powi(3))
                } else {
                    // Saturates: every residual beyond c costs c^2 / 6.
                    c * c / 6.0
                }
            }
            RobustLoss::Cauchy { sigma } => {
                (sigma * sigma / 2.0) * ((1.0 + (residual / sigma).powi(2)).ln())
            }
            RobustLoss::Logistic { scale } => scale * (1.0 + (-abs_r / scale).exp()).ln(),
            RobustLoss::Fair { c } => c * (abs_r / c - (1.0 + abs_r / c).ln()),
            RobustLoss::Welsch { c } => (c * c / 2.0) * (1.0 - (-((residual / c).powi(2))).exp()),
            RobustLoss::Custom { loss_fn, .. } => loss_fn(residual),
        }
    }

    /// IRLS weight for a single residual.
    ///
    /// Residuals with `|r| < 1e-10` short-circuit to weight 1.0 so that
    /// near-zero residuals never destabilize the reweighting.
    /// NOTE(review): this short-circuit also applies to `EpsilonInsensitive`,
    /// whose weight inside the tube is otherwise 0.0 — confirm intended.
    pub fn weight(&self, residual: Float) -> Float {
        let abs_r = residual.abs();
        if abs_r < 1e-10 {
            return 1.0;
        }

        match self {
            RobustLoss::Huber { delta } => {
                if abs_r <= *delta {
                    // Quadratic region behaves like ordinary least squares.
                    1.0
                } else {
                    // Downweight proportionally to 1 / |r| in the linear tail.
                    delta / abs_r
                }
            }
            RobustLoss::EpsilonInsensitive { epsilon } => {
                if abs_r <= *epsilon {
                    // Inside the tube: no contribution.
                    0.0
                } else {
                    1.0
                }
            }
            RobustLoss::Quantile { tau } => {
                // Asymmetric constant weights for over-/under-prediction.
                if residual >= 0.0 {
                    *tau
                } else {
                    1.0 - tau
                }
            }
            RobustLoss::Tukey { c } => {
                if abs_r <= *c {
                    let r_norm = residual / c;
                    (1.0 - r_norm * r_norm).powi(2)
                } else {
                    // Redescending: gross outliers are ignored entirely.
                    0.0
                }
            }
            RobustLoss::Cauchy { sigma } => 1.0 / (1.0 + (residual / sigma).powi(2)),
            RobustLoss::Logistic { scale } => {
                let exp_term = (-abs_r / scale).exp();
                exp_term / (1.0 + exp_term)
            }
            RobustLoss::Fair { c } => 1.0 / (1.0 + abs_r / c),
            RobustLoss::Welsch { c } => (-((residual / c).powi(2))).exp(),
            RobustLoss::Custom { weight_fn, .. } => weight_fn(residual),
        }
    }
}
175
impl RobustKernelRidgeRegression<Untrained> {
    /// Create an untrained model with default hyperparameters:
    /// `alpha = 1.0`, Huber loss (`delta = 1.0`), direct solver,
    /// `max_iter = 100`, `tolerance = 1e-6`, and no fixed random seed.
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            robust_loss: RobustLoss::default(),
            solver: Solver::Direct,
            max_iter: 100,
            tolerance: 1e-6,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            sample_weights_: None,
            _state: PhantomData,
        }
    }

    /// Set the L2 regularization strength (builder style).
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the robust loss used for IRLS reweighting.
    pub fn robust_loss(mut self, robust_loss: RobustLoss) -> Self {
        self.robust_loss = robust_loss;
        self
    }

    /// Choose the linear-system solver for each IRLS iteration.
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

    /// Set the maximum number of IRLS iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance on the coefficient change.
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Fix the random seed for the feature transformer (reproducibility).
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}
230
impl Estimator for RobustKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// No separate config object; the unit reference is promoted to 'static.
    fn config(&self) -> &Self::Config {
        &()
    }
}
240
241impl Fit<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Untrained> {
242 type Fitted = RobustKernelRidgeRegression<Trained>;
243
244 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
245 if x.nrows() != y.len() {
246 return Err(SklearsError::InvalidInput(
247 "Number of samples must match".to_string(),
248 ));
249 }
250
251 let feature_transformer = self.fit_feature_transformer(x)?;
253 let x_transformed = feature_transformer.transform(x)?;
254
255 let (weights, sample_weights) = self.solve_robust_regression(&x_transformed, y)?;
257
258 Ok(RobustKernelRidgeRegression {
259 approximation_method: self.approximation_method,
260 alpha: self.alpha,
261 robust_loss: self.robust_loss,
262 solver: self.solver,
263 max_iter: self.max_iter,
264 tolerance: self.tolerance,
265 random_state: self.random_state,
266 weights_: Some(weights),
267 feature_transformer_: Some(feature_transformer),
268 sample_weights_: Some(sample_weights),
269 _state: PhantomData,
270 })
271 }
272}
273
274impl RobustKernelRidgeRegression<Untrained> {
275 fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
277 match &self.approximation_method {
278 ApproximationMethod::Nystroem {
279 kernel,
280 n_components,
281 sampling_strategy,
282 } => {
283 let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
284 .sampling_strategy(sampling_strategy.clone());
285 if let Some(seed) = self.random_state {
286 nystroem = nystroem.random_state(seed);
287 }
288 let fitted = nystroem.fit(x, &())?;
289 Ok(FeatureTransformer::Nystroem(fitted))
290 }
291 ApproximationMethod::RandomFourierFeatures {
292 n_components,
293 gamma,
294 } => {
295 let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
296 if let Some(seed) = self.random_state {
297 rff = rff.random_state(seed);
298 }
299 let fitted = rff.fit(x, &())?;
300 Ok(FeatureTransformer::RBFSampler(fitted))
301 }
302 ApproximationMethod::StructuredRandomFeatures {
303 n_components,
304 gamma,
305 } => {
306 let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
307 if let Some(seed) = self.random_state {
308 srf = srf.random_state(seed);
309 }
310 let fitted = srf.fit(x, &())?;
311 Ok(FeatureTransformer::StructuredRFF(fitted))
312 }
313 ApproximationMethod::Fastfood {
314 n_components,
315 gamma,
316 } => {
317 let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
318 if let Some(seed) = self.random_state {
319 fastfood = fastfood.random_state(seed);
320 }
321 let fitted = fastfood.fit(x, &())?;
322 Ok(FeatureTransformer::Fastfood(fitted))
323 }
324 }
325 }
326
327 fn solve_robust_regression(
329 &self,
330 x: &Array2<Float>,
331 y: &Array1<Float>,
332 ) -> Result<(Array1<Float>, Array1<Float>)> {
333 let n_samples = x.nrows();
334 let n_features = x.ncols();
335
336 let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
338 let y_f64 = Array1::from_vec(y.iter().copied().collect());
339
340 let xtx = x_f64.t().dot(&x_f64);
341 let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;
342 let xty = x_f64.t().dot(&y_f64);
343 let mut weights_f64 =
344 regularized_xtx
345 .solve(&xty)
346 .map_err(|e| SklearsError::InvalidParameter {
347 name: "regularization".to_string(),
348 reason: format!("Initial linear system solving failed: {:?}", e),
349 })?;
350
351 let mut sample_weights = Array1::ones(n_samples);
352 let mut prev_weights = weights_f64.clone();
353
354 for _iter in 0..self.max_iter {
356 let predictions = x_f64.dot(&weights_f64);
358 let residuals = &y_f64 - &predictions;
359
360 for (i, &residual) in residuals.iter().enumerate() {
362 sample_weights[i] = self.robust_loss.weight(residual as Float);
363 }
364
365 let mut weighted_xtx = Array2::zeros((n_features, n_features));
367 let mut weighted_xty = Array1::zeros(n_features);
368
369 for i in 0..n_samples {
370 let weight = sample_weights[i];
371 let x_row = x_f64.row(i);
372
373 for j in 0..n_features {
375 for k in 0..n_features {
376 weighted_xtx[[j, k]] += weight * x_row[j] * x_row[k];
377 }
378 }
379
380 for j in 0..n_features {
382 weighted_xty[j] += weight * x_row[j] * y_f64[i];
383 }
384 }
385
386 weighted_xtx += &(Array2::eye(n_features) * self.alpha);
388
389 weights_f64 = match self.solver {
391 Solver::Direct => weighted_xtx.solve(&weighted_xty).map_err(|e| {
392 SklearsError::InvalidParameter {
393 name: "weighted_system".to_string(),
394 reason: format!("Weighted linear system solving failed: {:?}", e),
395 }
396 })?,
397 Solver::SVD => {
398 let (u, s, vt) = weighted_xtx.svd(true, true).map_err(|e| {
399 SklearsError::InvalidParameter {
400 name: "svd".to_string(),
401 reason: format!("SVD decomposition failed: {:?}", e),
402 }
403 })?;
404 let u = u.unwrap();
405 let vt = vt.unwrap();
406 let ut_b = u.t().dot(&weighted_xty);
407 let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
408 let y_svd = ut_b * s_inv;
409 vt.t().dot(&y_svd)
410 }
411 Solver::ConjugateGradient { max_iter, tol } => {
412 self.solve_cg_weighted(&weighted_xtx, &weighted_xty, max_iter, tol)?
413 }
414 };
415
416 let weight_change = (&weights_f64 - &prev_weights).mapv(|x| x.abs()).sum();
418 if weight_change < self.tolerance {
419 break;
420 }
421
422 prev_weights = weights_f64.clone();
423 }
424
425 let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
427 let sample_weights_float =
428 Array1::from_vec(sample_weights.iter().map(|&val| val as Float).collect());
429
430 Ok((weights, sample_weights_float))
431 }
432
433 fn solve_cg_weighted(
435 &self,
436 a: &Array2<f64>,
437 b: &Array1<f64>,
438 max_iter: usize,
439 tol: f64,
440 ) -> Result<Array1<f64>> {
441 let n = b.len();
442 let mut x = Array1::zeros(n);
443 let mut r = b - &a.dot(&x);
444 let mut p = r.clone();
445 let mut rsold = r.dot(&r);
446
447 for _iter in 0..max_iter {
448 let ap = a.dot(&p);
449 let alpha = rsold / p.dot(&ap);
450
451 x = x + &p * alpha;
452 r = r - &ap * alpha;
453
454 let rsnew = r.dot(&r);
455
456 if rsnew.sqrt() < tol {
457 break;
458 }
459
460 let beta = rsnew / rsold;
461 p = &r + &p * beta;
462 rsold = rsnew;
463 }
464
465 Ok(x)
466 }
467}
468
469impl Predict<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Trained> {
470 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
471 let feature_transformer =
472 self.feature_transformer_
473 .as_ref()
474 .ok_or_else(|| SklearsError::NotFitted {
475 operation: "predict".to_string(),
476 })?;
477
478 let weights = self
479 .weights_
480 .as_ref()
481 .ok_or_else(|| SklearsError::NotFitted {
482 operation: "predict".to_string(),
483 })?;
484
485 let x_transformed = feature_transformer.transform(x)?;
486 let predictions = x_transformed.dot(weights);
487
488 Ok(predictions)
489 }
490}
491
492impl RobustKernelRidgeRegression<Trained> {
493 pub fn weights(&self) -> Option<&Array1<Float>> {
495 self.weights_.as_ref()
496 }
497
498 pub fn sample_weights(&self) -> Option<&Array1<Float>> {
500 self.sample_weights_.as_ref()
501 }
502
503 pub fn robust_residuals(
505 &self,
506 x: &Array2<Float>,
507 y: &Array1<Float>,
508 ) -> Result<(Array1<Float>, Array1<Float>)> {
509 let predictions = self.predict(x)?;
510 let residuals = y - &predictions;
511
512 let mut weights = Array1::zeros(residuals.len());
513 for (i, &residual) in residuals.iter().enumerate() {
514 weights[i] = self.robust_loss.weight(residual);
515 }
516
517 Ok((residuals, weights))
518 }
519
520 pub fn outlier_scores(&self) -> Option<Array1<Float>> {
522 self.sample_weights_.as_ref().map(|weights| {
523 weights.mapv(|w| 1.0 - w)
525 })
526 }
527}
528
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_robust_kernel_ridge_regression() {
        // The last sample is a gross outlier in both x and y.
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [10.0, 10.0]];
        let y = array![1.0, 2.0, 3.0, 100.0];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let model = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = model.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|p| p.is_finite()));

        // The outlier should receive a smaller IRLS weight than an inlier.
        let sample_weights = fitted.sample_weights().unwrap();
        assert!(sample_weights[3] < sample_weights[0]);
    }

    #[test]
    fn test_different_robust_losses() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Every loss variant should fit and predict without error.
        let loss_functions = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
        ];

        for loss in loss_functions {
            let model = RobustKernelRidgeRegression::new(approximation.clone())
                .alpha(0.1)
                .robust_loss(loss);

            let predictions = model.fit(&x, &y).unwrap().predict(&x).unwrap();
            assert_eq!(predictions.len(), 3);
        }
    }

    #[test]
    fn test_robust_loss_functions() {
        let huber = RobustLoss::Huber { delta: 1.0 };

        // Quadratic region: 0.5 * r^2.
        assert_eq!(huber.loss(0.5), 0.125);
        // Linear region: delta * (|r| - delta / 2).
        assert_eq!(huber.loss(2.0), 1.5);
        // Weight is 1 inside the quadratic region, delta / |r| outside.
        assert_eq!(huber.weight(0.5), 1.0);
        assert_eq!(huber.weight(2.0), 0.5);
    }

    #[test]
    fn test_robust_outlier_detection() {
        // Sample 3 is an outlier in both feature and target.
        let x = array![[1.0], [2.0], [3.0], [100.0]];
        let y = array![1.0, 2.0, 3.0, 100.0];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let model = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = model.fit(&x, &y).unwrap();
        let outlier_scores = fitted.outlier_scores().unwrap();

        // The outlier's score should exceed every inlier's score.
        assert!(outlier_scores[3] > outlier_scores[0]);
        assert!(outlier_scores[3] > outlier_scores[1]);
        assert!(outlier_scores[3] > outlier_scores[2]);
    }

    #[test]
    fn test_robust_convergence() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Tight iteration budget with a loose tolerance still yields
        // finite predictions.
        let model = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .max_iter(5)
            .tolerance(1e-3);

        let predictions = model.fit(&x, &y).unwrap().predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
        assert!(predictions.iter().all(|p| p.is_finite()));
    }

    #[test]
    fn test_robust_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Same seed -> same random features -> same predictions.
        let fitted_a = RobustKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42)
            .fit(&x, &y)
            .unwrap();
        let pred_a = fitted_a.predict(&x).unwrap();

        let fitted_b = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42)
            .fit(&x, &y)
            .unwrap();
        let pred_b = fitted_b.predict(&x).unwrap();

        assert_eq!(pred_a.len(), pred_b.len());
        for i in 0..pred_a.len() {
            assert!((pred_a[i] - pred_b[i]).abs() < 1e-10);
        }
    }

    #[test]
    fn test_robust_loss_edge_cases() {
        let losses = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
            RobustLoss::Logistic { scale: 1.0 },
            RobustLoss::Fair { c: 1.0 },
            RobustLoss::Welsch { c: 1.0 },
        ];

        for loss in losses {
            // At residual 0 the loss is non-negative and the weight is
            // bounded.
            let loss_zero = loss.loss(0.0);
            let weight_zero = loss.weight(0.0);

            assert!(loss_zero >= 0.0);
            assert!(weight_zero >= 0.0);
            assert!(weight_zero <= 1.5);

            // A unit residual also yields non-negative loss and weight.
            let loss_nonzero = loss.loss(1.0);
            let weight_nonzero = loss.weight(1.0);

            assert!(loss_nonzero >= 0.0);
            assert!(weight_nonzero >= 0.0);
        }
    }

    #[test]
    fn test_custom_robust_loss() {
        // Plain squared loss with constant (non-robust) weights.
        let custom_loss = RobustLoss::Custom {
            loss_fn: |r| r * r,
            weight_fn: |_| 1.0,
        };

        assert_eq!(custom_loss.loss(2.0), 4.0);
        assert_eq!(custom_loss.weight(5.0), 1.0);
    }
}