1use crate::{
8 FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
9};
10use scirs2_linalg::compat::ArrayLinalgExt;
11use scirs2_core::ndarray::{Array1, Array2};
14use sklears_core::error::{Result, SklearsError};
15use sklears_core::prelude::{Estimator, Fit, Float, Predict};
16use std::marker::PhantomData;
17
18use super::core_types::*;
19
/// Kernel ridge regression made robust to outliers.
///
/// The kernel is approximated by an explicit random-feature map (see
/// `approximation_method`), and the coefficients are fitted by iteratively
/// reweighted least squares (IRLS) driven by the configured `robust_loss`.
/// The `State` type parameter is a typestate marker (`Untrained`/`Trained`).
#[derive(Debug, Clone)]
pub struct RobustKernelRidgeRegression<State = Untrained> {
    /// Kernel approximation used to build the explicit feature map.
    pub approximation_method: ApproximationMethod,
    /// L2 (ridge) regularization strength added to the normal equations.
    pub alpha: Float,
    /// Robust loss whose IRLS weights down-weight large residuals.
    pub robust_loss: RobustLoss,
    /// Linear-system solver used inside each IRLS iteration.
    pub solver: Solver,
    /// Maximum number of IRLS iterations.
    pub max_iter: usize,
    /// Convergence threshold on the L1 change of the coefficients.
    pub tolerance: Float,
    /// Optional seed forwarded to the random feature transformer.
    pub random_state: Option<u64>,

    // Learned coefficients in feature space (populated by `fit`).
    weights_: Option<Array1<Float>>,
    // Fitted feature transformer (populated by `fit`).
    feature_transformer_: Option<FeatureTransformer>,
    // Final per-sample IRLS weights from training (populated by `fit`).
    sample_weights_: Option<Array1<Float>>,
    // Zero-sized typestate marker.
    _state: PhantomData<State>,
}
51
/// Robust loss functions for outlier-resistant regression.
///
/// Each variant supplies both a loss value (`loss`) and an IRLS weight
/// (`weight`); the fitting loop only consumes the weights.
#[derive(Debug, Clone)]
pub enum RobustLoss {
    /// Quadratic near zero, linear beyond `delta`.
    Huber { delta: Float },
    /// Zero inside the `epsilon` tube, linear outside (SVR-style).
    EpsilonInsensitive { epsilon: Float },
    /// Pinball loss for quantile regression; `tau` is the target quantile.
    Quantile { tau: Float },
    /// Tukey biweight; saturates for |r| > `c`, fully rejecting gross outliers.
    Tukey { c: Float },
    /// Logarithmically growing loss with scale `sigma`.
    Cauchy { sigma: Float },
    /// Logistic-shaped weighting with the given `scale`.
    Logistic { scale: Float },
    /// Fair loss; everywhere differentiable, between L1 and L2 behavior.
    Fair { c: Float },
    /// Welsch loss; exponentially saturating, similar tails to Tukey.
    Welsch { c: Float },
    /// User-supplied loss/weight pair.
    Custom {
        /// Loss value as a function of the raw residual.
        loss_fn: fn(Float) -> Float,
        /// IRLS weight as a function of the raw residual.
        weight_fn: fn(Float) -> Float,
    },
}
77
78impl Default for RobustLoss {
79 fn default() -> Self {
80 Self::Huber { delta: 1.0 }
81 }
82}
83
impl RobustLoss {
    /// Evaluate the loss value rho(residual).
    ///
    /// Used for diagnostics; the IRLS solver in this file relies on
    /// [`RobustLoss::weight`] instead. All variants below are non-negative
    /// for their usual parameter ranges (e.g. `tau` in [0, 1] for Quantile).
    pub fn loss(&self, residual: Float) -> Float {
        let abs_r = residual.abs();
        match self {
            // Quadratic inside |r| <= delta, linear beyond; continuous at delta.
            RobustLoss::Huber { delta } => {
                if abs_r <= *delta {
                    0.5 * residual * residual
                } else {
                    delta * (abs_r - 0.5 * delta)
                }
            }
            // Zero inside the epsilon tube, linear outside.
            RobustLoss::EpsilonInsensitive { epsilon } => (abs_r - epsilon).max(0.0),
            // Pinball loss: slope tau above zero, tau - 1 below.
            RobustLoss::Quantile { tau } => {
                if residual >= 0.0 {
                    tau * residual
                } else {
                    (tau - 1.0) * residual
                }
            }
            // Tukey biweight: saturates at c^2 / 6 for |r| > c.
            RobustLoss::Tukey { c } => {
                if abs_r <= *c {
                    let r_norm = residual / c;
                    (c * c / 6.0) * (1.0 - (1.0 - r_norm * r_norm).powi(3))
                } else {
                    c * c / 6.0
                }
            }
            // Grows logarithmically, so large residuals have bounded influence.
            RobustLoss::Cauchy { sigma } => {
                (sigma * sigma / 2.0) * ((1.0 + (residual / sigma).powi(2)).ln())
            }
            // NOTE(review): this expression *decreases* as |r| grows (the
            // exp(-|r|/scale) term shrinks), which is unusual for a loss.
            // Fitting is unaffected since the solver only calls `weight`,
            // but the intended formula should be confirmed.
            RobustLoss::Logistic { scale } => scale * (1.0 + (-abs_r / scale).exp()).ln(),
            // Fair loss: smooth everywhere, between L1 and L2 behavior.
            RobustLoss::Fair { c } => c * (abs_r / c - (1.0 + abs_r / c).ln()),
            // Welsch: exponentially saturating toward c^2 / 2.
            RobustLoss::Welsch { c } => (c * c / 2.0) * (1.0 - (-((residual / c).powi(2))).exp()),
            // Delegate to the user-supplied loss.
            RobustLoss::Custom { loss_fn, .. } => loss_fn(residual),
        }
    }

    /// IRLS weight w(residual) applied to each sample when re-solving the
    /// weighted least-squares system. Smaller weights mean the sample is
    /// treated as more outlier-like.
    pub fn weight(&self, residual: Float) -> Float {
        let abs_r = residual.abs();
        // Guard against division by ~0 in the psi(r)/r-style forms below:
        // a (near-)zero residual always gets full weight.
        if abs_r < 1e-10 {
            return 1.0;
        }

        match self {
            // 1 inside the quadratic region, delta/|r| in the linear tails.
            RobustLoss::Huber { delta } => {
                if abs_r <= *delta {
                    1.0
                } else {
                    delta / abs_r
                }
            }
            // Samples inside the tube contribute nothing.
            RobustLoss::EpsilonInsensitive { epsilon } => {
                if abs_r <= *epsilon {
                    0.0
                } else {
                    1.0
                }
            }
            // Asymmetric weighting by residual sign.
            RobustLoss::Quantile { tau } => {
                if residual >= 0.0 {
                    *tau
                } else {
                    1.0 - tau
                }
            }
            // Smoothly falls to exactly 0 at |r| = c (hard rejection beyond).
            RobustLoss::Tukey { c } => {
                if abs_r <= *c {
                    let r_norm = residual / c;
                    (1.0 - r_norm * r_norm).powi(2)
                } else {
                    0.0
                }
            }
            // 1 / (1 + (r/sigma)^2): heavy-tail down-weighting, never zero.
            RobustLoss::Cauchy { sigma } => 1.0 / (1.0 + (residual / sigma).powi(2)),
            // Sigmoid-shaped decay from 0.5 at r = 0 toward 0 for large |r|.
            RobustLoss::Logistic { scale } => {
                let exp_term = (-abs_r / scale).exp();
                exp_term / (1.0 + exp_term)
            }
            // Harmonic decay in |r|/c.
            RobustLoss::Fair { c } => 1.0 / (1.0 + abs_r / c),
            // Gaussian-shaped decay.
            RobustLoss::Welsch { c } => (-((residual / c).powi(2))).exp(),
            // Delegate to the user-supplied weight.
            RobustLoss::Custom { weight_fn, .. } => weight_fn(residual),
        }
    }
}
171
172impl RobustKernelRidgeRegression<Untrained> {
173 pub fn new(approximation_method: ApproximationMethod) -> Self {
175 Self {
176 approximation_method,
177 alpha: 1.0,
178 robust_loss: RobustLoss::default(),
179 solver: Solver::Direct,
180 max_iter: 100,
181 tolerance: 1e-6,
182 random_state: None,
183 weights_: None,
184 feature_transformer_: None,
185 sample_weights_: None,
186 _state: PhantomData,
187 }
188 }
189
190 pub fn alpha(mut self, alpha: Float) -> Self {
192 self.alpha = alpha;
193 self
194 }
195
196 pub fn robust_loss(mut self, robust_loss: RobustLoss) -> Self {
198 self.robust_loss = robust_loss;
199 self
200 }
201
202 pub fn solver(mut self, solver: Solver) -> Self {
204 self.solver = solver;
205 self
206 }
207
208 pub fn max_iter(mut self, max_iter: usize) -> Self {
210 self.max_iter = max_iter;
211 self
212 }
213
214 pub fn tolerance(mut self, tolerance: Float) -> Self {
216 self.tolerance = tolerance;
217 self
218 }
219
220 pub fn random_state(mut self, seed: u64) -> Self {
222 self.random_state = Some(seed);
223 self
224 }
225}
226
// Marker `Estimator` implementation. This model carries no separate
// configuration object, so `config` returns the unit value.
impl Estimator for RobustKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is a valid 'static reference to the zero-sized unit value.
        &()
    }
}
236
237impl Fit<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Untrained> {
238 type Fitted = RobustKernelRidgeRegression<Trained>;
239
240 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
241 if x.nrows() != y.len() {
242 return Err(SklearsError::InvalidInput(
243 "Number of samples must match".to_string(),
244 ));
245 }
246
247 let feature_transformer = self.fit_feature_transformer(x)?;
249 let x_transformed = feature_transformer.transform(x)?;
250
251 let (weights, sample_weights) = self.solve_robust_regression(&x_transformed, y)?;
253
254 Ok(RobustKernelRidgeRegression {
255 approximation_method: self.approximation_method,
256 alpha: self.alpha,
257 robust_loss: self.robust_loss,
258 solver: self.solver,
259 max_iter: self.max_iter,
260 tolerance: self.tolerance,
261 random_state: self.random_state,
262 weights_: Some(weights),
263 feature_transformer_: Some(feature_transformer),
264 sample_weights_: Some(sample_weights),
265 _state: PhantomData,
266 })
267 }
268}
269
impl RobustKernelRidgeRegression<Untrained> {
    /// Build and fit the configured kernel-approximation transformer on `x`,
    /// forwarding `self.random_state` to the transformer when set.
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());
                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }
                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    rff = rff.random_state(seed);
                }
                let fitted = rff.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    srf = srf.random_state(seed);
                }
                let fitted = srf.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }
                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Solve the robust ridge problem in feature space by IRLS.
    ///
    /// Returns `(coefficients, per_sample_weights)` where the coefficients
    /// minimize the reweighted ridge objective and the weights are the
    /// final IRLS weights (small weight => outlier-like sample).
    fn solve_robust_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(Array1<Float>, Array1<Float>)> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Copy inputs into f64 arrays for the linear algebra below.
        // NOTE(review): if `Float` is already f64 these copies are redundant.
        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
        let y_f64 = Array1::from_vec(y.iter().copied().collect());

        // Warm start from the ordinary (unweighted) ridge solution of
        // (X^T X + alpha I) w = X^T y.
        let xtx = x_f64.t().dot(&x_f64);
        let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;
        let xty = x_f64.t().dot(&y_f64);
        let mut weights_f64 =
            regularized_xtx
                .solve(&xty)
                .map_err(|e| SklearsError::InvalidParameter {
                    name: "regularization".to_string(),
                    reason: format!("Initial linear system solving failed: {:?}", e),
                })?;

        let mut sample_weights = Array1::ones(n_samples);
        let mut prev_weights = weights_f64.clone();

        // IRLS loop: alternate between robust per-sample weights from the
        // current residuals and re-solving the weighted ridge system.
        for _iter in 0..self.max_iter {
            let predictions = x_f64.dot(&weights_f64);
            let residuals = &y_f64 - &predictions;

            // w_i from the configured robust loss.
            for (i, &residual) in residuals.iter().enumerate() {
                sample_weights[i] = self.robust_loss.weight(residual as Float);
            }

            // Accumulate X^T W X and X^T W y explicitly.
            // NOTE(review): this triple loop is O(n_samples * n_features^2);
            // scaling each row by sqrt(w_i) and using matrix products would
            // vectorize it (at the cost of a different float summation order).
            let mut weighted_xtx = Array2::zeros((n_features, n_features));
            let mut weighted_xty = Array1::zeros(n_features);

            for i in 0..n_samples {
                let weight = sample_weights[i];
                let x_row = x_f64.row(i);

                for j in 0..n_features {
                    for k in 0..n_features {
                        weighted_xtx[[j, k]] += weight * x_row[j] * x_row[k];
                    }
                }

                for j in 0..n_features {
                    weighted_xty[j] += weight * x_row[j] * y_f64[i];
                }
            }

            // Ridge regularization term.
            weighted_xtx += &(Array2::eye(n_features) * self.alpha);

            weights_f64 = match self.solver {
                Solver::Direct => weighted_xtx.solve(&weighted_xty).map_err(|e| {
                    SklearsError::InvalidParameter {
                        name: "weighted_system".to_string(),
                        reason: format!("Weighted linear system solving failed: {:?}", e),
                    }
                })?,
                Solver::SVD => {
                    // Pseudo-inverse solve: singular values below 1e-10 are
                    // treated as zero so rank-deficient systems still solve.
                    let (u, s, vt) =
                        weighted_xtx
                            .svd(true)
                            .map_err(|e| SklearsError::InvalidParameter {
                                name: "svd".to_string(),
                                reason: format!("SVD decomposition failed: {:?}", e),
                            })?;
                    let ut_b = u.t().dot(&weighted_xty);
                    let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
                    let y_svd = ut_b * s_inv;
                    vt.t().dot(&y_svd)
                }
                Solver::ConjugateGradient { max_iter, tol } => {
                    self.solve_cg_weighted(&weighted_xtx, &weighted_xty, max_iter, tol)?
                }
            };

            // Stop when the L1 change of the coefficients falls below
            // tolerance; `prev_weights` is only advanced when we continue.
            let weight_change = (&weights_f64 - &prev_weights).mapv(|x| x.abs()).sum();
            if weight_change < self.tolerance {
                break;
            }

            prev_weights = weights_f64.clone();
        }

        // Cast back to `Float` for storage on the model.
        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
        let sample_weights_float =
            Array1::from_vec(sample_weights.iter().map(|&val| val as Float).collect());

        Ok((weights, sample_weights_float))
    }

    /// Unpreconditioned conjugate gradient for `a * x = b`, starting from
    /// x = 0. Stops after `max_iter` iterations or once the residual norm
    /// drops below `tol`. Assumes `a` is symmetric positive definite (the
    /// regularized weighted normal matrix built above).
    fn solve_cg_weighted(
        &self,
        a: &Array2<f64>,
        b: &Array1<f64>,
        max_iter: usize,
        tol: f64,
    ) -> Result<Array1<f64>> {
        let n = b.len();
        let mut x = Array1::zeros(n);
        let mut r = b - &a.dot(&x);
        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for _iter in 0..max_iter {
            let ap = a.dot(&p);
            // Step length along the current search direction.
            let alpha = rsold / p.dot(&ap);

            x = x + &p * alpha;
            r = r - &ap * alpha;

            let rsnew = r.dot(&r);

            if rsnew.sqrt() < tol {
                break;
            }

            // New direction: residual plus a beta-scaled previous direction.
            let beta = rsnew / rsold;
            p = &r + &p * beta;
            rsold = rsnew;
        }

        Ok(x)
    }
}
463
464impl Predict<Array2<Float>, Array1<Float>> for RobustKernelRidgeRegression<Trained> {
465 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
466 let feature_transformer =
467 self.feature_transformer_
468 .as_ref()
469 .ok_or_else(|| SklearsError::NotFitted {
470 operation: "predict".to_string(),
471 })?;
472
473 let weights = self
474 .weights_
475 .as_ref()
476 .ok_or_else(|| SklearsError::NotFitted {
477 operation: "predict".to_string(),
478 })?;
479
480 let x_transformed = feature_transformer.transform(x)?;
481 let predictions = x_transformed.dot(weights);
482
483 Ok(predictions)
484 }
485}
486
487impl RobustKernelRidgeRegression<Trained> {
488 pub fn weights(&self) -> Option<&Array1<Float>> {
490 self.weights_.as_ref()
491 }
492
493 pub fn sample_weights(&self) -> Option<&Array1<Float>> {
495 self.sample_weights_.as_ref()
496 }
497
498 pub fn robust_residuals(
500 &self,
501 x: &Array2<Float>,
502 y: &Array1<Float>,
503 ) -> Result<(Array1<Float>, Array1<Float>)> {
504 let predictions = self.predict(x)?;
505 let residuals = y - &predictions;
506
507 let mut weights = Array1::zeros(residuals.len());
508 for (i, &residual) in residuals.iter().enumerate() {
509 weights[i] = self.robust_loss.weight(residual);
510 }
511
512 Ok((residuals, weights))
513 }
514
515 pub fn outlier_scores(&self) -> Option<Array1<Float>> {
517 self.sample_weights_.as_ref().map(|weights| {
518 weights.mapv(|w| 1.0 - w)
520 })
521 }
522}
523
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// End-to-end fit/predict with one gross outlier: predictions must be
    /// finite and the outlier must be down-weighted relative to inliers.
    #[test]
    fn test_robust_kernel_ridge_regression() {
        // Last sample is an outlier in both features and target.
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [10.0, 10.0]];
        let y = array![1.0, 2.0, 3.0, 100.0];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = robust_krr.fit(&x, &y).expect("operation should succeed");
        let predictions = fitted.predict(&x).expect("operation should succeed");

        assert_eq!(predictions.len(), 4);

        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }

        // The outlier (index 3) gets a smaller IRLS weight than an inlier.
        let sample_weights = fitted.sample_weights().expect("operation should succeed");
        assert!(sample_weights[3] < sample_weights[0]);
    }

    /// Smoke test: fitting succeeds and predicts for several loss variants.
    #[test]
    fn test_different_robust_losses() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let loss_functions = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
        ];

        for loss in loss_functions {
            let robust_krr = RobustKernelRidgeRegression::new(approximation.clone())
                .alpha(0.1)
                .robust_loss(loss);

            let fitted = robust_krr.fit(&x, &y).expect("operation should succeed");
            let predictions = fitted.predict(&x).expect("operation should succeed");

            assert_eq!(predictions.len(), 3);
        }
    }

    /// Spot-check exact Huber loss/weight values at delta = 1.0.
    #[test]
    fn test_robust_loss_functions() {
        let huber = RobustLoss::Huber { delta: 1.0 };

        assert_eq!(huber.loss(0.5), 0.125); // quadratic region: 0.5 * 0.5^2
        assert_eq!(huber.loss(2.0), 1.5); // linear region: 1.0 * (2.0 - 0.5)
        assert_eq!(huber.weight(0.5), 1.0); // full weight inside delta
        assert_eq!(huber.weight(2.0), 0.5); // delta / |r|
    }

    /// Outlier scores (1 - weight) must rank the planted outlier highest.
    #[test]
    fn test_robust_outlier_detection() {
        // Sample 3 is far from the others in both x and y.
        let x = array![[1.0], [2.0], [3.0], [100.0]];
        let y = array![1.0, 2.0, 3.0, 100.0];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 });

        let fitted = robust_krr.fit(&x, &y).expect("operation should succeed");
        let outlier_scores = fitted.outlier_scores().expect("operation should succeed");

        assert!(outlier_scores[3] > outlier_scores[0]);
        assert!(outlier_scores[3] > outlier_scores[1]);
        assert!(outlier_scores[3] > outlier_scores[2]);
    }

    /// IRLS with a small iteration budget and loose tolerance still yields
    /// finite predictions.
    #[test]
    fn test_robust_convergence() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .max_iter(5) // deliberately tiny budget
            .tolerance(1e-3);

        let fitted = robust_krr.fit(&x, &y).expect("operation should succeed");
        let predictions = fitted.predict(&x).expect("operation should succeed");

        assert_eq!(predictions.len(), 3);
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    /// Same seed => (numerically) identical predictions across two fits.
    #[test]
    fn test_robust_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let robust_krr1 = RobustKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42);
        let fitted1 = robust_krr1.fit(&x, &y).expect("operation should succeed");
        let pred1 = fitted1.predict(&x).expect("operation should succeed");

        let robust_krr2 = RobustKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .robust_loss(RobustLoss::Huber { delta: 1.0 })
            .random_state(42);
        let fitted2 = robust_krr2.fit(&x, &y).expect("operation should succeed");
        let pred2 = fitted2.predict(&x).expect("operation should succeed");

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }

    /// Sanity at r = 0 and r = 1 for every built-in variant: losses and
    /// weights are non-negative, weights stay in a sensible range.
    #[test]
    fn test_robust_loss_edge_cases() {
        let losses = vec![
            RobustLoss::Huber { delta: 1.0 },
            RobustLoss::EpsilonInsensitive { epsilon: 0.1 },
            RobustLoss::Quantile { tau: 0.5 },
            RobustLoss::Tukey { c: 4.685 },
            RobustLoss::Cauchy { sigma: 1.0 },
            RobustLoss::Logistic { scale: 1.0 },
            RobustLoss::Fair { c: 1.0 },
            RobustLoss::Welsch { c: 1.0 },
        ];

        for loss in losses {
            let loss_zero = loss.loss(0.0);
            let weight_zero = loss.weight(0.0);

            assert!(loss_zero >= 0.0);
            assert!(weight_zero >= 0.0);
            assert!(weight_zero <= 1.5); // all built-ins cap near 1.0

            let loss_nonzero = loss.loss(1.0);
            let weight_nonzero = loss.weight(1.0);

            assert!(loss_nonzero >= 0.0);
            assert!(weight_nonzero >= 0.0);
        }
    }

    /// Custom variant delegates to the user-supplied fn pointers.
    #[test]
    fn test_custom_robust_loss() {
        let custom_loss = RobustLoss::Custom {
            loss_fn: |r| r * r,   // plain squared loss
            weight_fn: |_| 1.0,   // constant weight (ordinary least squares)
        };

        assert_eq!(custom_loss.loss(2.0), 4.0);
        assert_eq!(custom_loss.weight(5.0), 1.0);
    }
}