1use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_linalg::compat::ArrayLinalgExt;
9use std::marker::PhantomData;
11
12use sklears_core::{
13 error::{Result, SklearsError},
14 traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
15 types::{Float, Int},
16};
17
18use crate::solver::Solver;
19
20fn safe_mean_axis(arr: &Array2<Float>, axis: Axis) -> Result<Array1<Float>> {
22 if arr.is_empty() {
23 return Err(SklearsError::InvalidInput(
24 "Cannot compute mean of empty array".to_string(),
25 ));
26 }
27 arr.mean_axis(axis).ok_or_else(|| {
28 SklearsError::InvalidInput("Mean computation failed (empty axis)".to_string())
29 })
30}
31
32fn compare_floats(a: &Float, b: &Float) -> Result<std::cmp::Ordering> {
34 a.partial_cmp(b)
35 .ok_or_else(|| SklearsError::InvalidInput("NaN encountered in comparison".to_string()))
36}
37
/// Hyperparameters for [`RidgeClassifier`].
#[derive(Debug, Clone)]
pub struct RidgeClassifierConfig {
    /// L2 regularization strength; larger values shrink coefficients more.
    pub alpha: Float,
    /// Whether to fit an intercept (features and targets are centered in `fit`).
    pub fit_intercept: bool,
    /// Whether to normalize features before fitting.
    /// NOTE(review): not consumed by the closed-form `fit` in this file — confirm intended use.
    pub normalize: bool,
    /// Solver selection strategy.
    /// NOTE(review): not consumed by the closed-form `fit` in this file — confirm intended use.
    pub solver: Solver,
    /// Maximum iterations for iterative solvers; `None` means solver default.
    pub max_iter: Option<usize>,
    /// Convergence tolerance for iterative solvers.
    pub tol: Float,
    /// Seed for solvers that use randomness; `None` leaves it unseeded.
    pub random_state: Option<u64>,
}
56
57impl Default for RidgeClassifierConfig {
58 fn default() -> Self {
59 Self {
60 alpha: 1.0,
61 fit_intercept: true,
62 normalize: false,
63 solver: Solver::Auto,
64 max_iter: None,
65 tol: 1e-3,
66 random_state: None,
67 }
68 }
69}
70
/// Ridge (L2-regularized least-squares) classifier with a typestate
/// parameter: `Untrained` until `fit` succeeds, then `Trained`.
pub struct RidgeClassifier<State = Untrained> {
    // Hyperparameters used during fitting.
    config: RidgeClassifierConfig,
    // Zero-sized typestate marker (Untrained / Trained).
    state: PhantomData<State>,
    // Learned weights, one row per class; `Some` only in the Trained state.
    coef_: Option<Array2<Float>>,
    // Learned intercepts, one per class; `Some` only in the Trained state.
    intercept_: Option<Array1<Float>>,
    // Sorted unique class labels seen during fit; `Some` only when Trained.
    classes_: Option<Array1<Int>>,
    // Number of features seen during fit; `Some` only when Trained.
    n_features_in_: Option<usize>,
}
80
81impl RidgeClassifier<Untrained> {
82 pub fn new() -> Self {
84 Self {
85 config: RidgeClassifierConfig::default(),
86 state: PhantomData,
87 coef_: None,
88 intercept_: None,
89 classes_: None,
90 n_features_in_: None,
91 }
92 }
93
94 pub fn alpha(mut self, alpha: Float) -> Self {
96 self.config.alpha = alpha;
97 self
98 }
99
100 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
102 self.config.fit_intercept = fit_intercept;
103 self
104 }
105
106 pub fn normalize(mut self, normalize: bool) -> Self {
108 self.config.normalize = normalize;
109 self
110 }
111
112 pub fn solver(mut self, solver: Solver) -> Self {
114 self.config.solver = solver;
115 self
116 }
117
118 pub fn tol(mut self, tol: Float) -> Self {
120 self.config.tol = tol;
121 self
122 }
123}
124
impl Default for RidgeClassifier<Untrained> {
    /// Equivalent to [`RidgeClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
130
// Estimator trait plumbing for the untrained typestate.
impl Estimator for RidgeClassifier<Untrained> {
    type Float = Float;
    type Config = RidgeClassifierConfig;
    type Error = SklearsError;

    /// Borrow the hyperparameter configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
140
// Estimator trait plumbing for the trained typestate (same config access).
impl Estimator for RidgeClassifier<Trained> {
    type Float = Float;
    type Config = RidgeClassifierConfig;
    type Error = SklearsError;

    /// Borrow the hyperparameter configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
150
151fn label_binarize(y: &Array1<Int>, classes: &[Int]) -> Array2<Float> {
153 let n_samples = y.len();
154 let n_classes = classes.len();
155
156 if n_classes == 2 {
157 let mut y_bin = Array1::zeros(n_samples);
159 for (i, &label) in y.iter().enumerate() {
160 if label == classes[1] {
161 y_bin[i] = 1.0;
162 } else {
163 y_bin[i] = -1.0;
164 }
165 }
166 y_bin.insert_axis(Axis(1))
167 } else {
168 let mut y_bin = Array2::from_elem((n_samples, n_classes), -1.0);
170 for (i, &label) in y.iter().enumerate() {
171 for (j, &class) in classes.iter().enumerate() {
172 if label == class {
173 y_bin[[i, j]] = 1.0;
174 }
175 }
176 }
177 y_bin
178 }
179}
180
181impl Fit<Array2<Float>, Array1<Int>> for RidgeClassifier<Untrained> {
182 type Fitted = RidgeClassifier<Trained>;
183
184 fn fit(self, x: &Array2<Float>, y: &Array1<Int>) -> Result<Self::Fitted> {
185 let n_samples = x.nrows();
186 let n_features = x.ncols();
187
188 if n_samples != y.len() {
189 return Err(SklearsError::InvalidInput(
190 "X and y must have the same number of samples".to_string(),
191 ));
192 }
193
194 let mut classes: Vec<Int> = y.iter().copied().collect();
196 classes.sort_unstable();
197 classes.dedup();
198 let n_classes = classes.len();
199
200 if n_classes < 2 {
201 return Err(SklearsError::InvalidInput(
202 "At least two classes are required".to_string(),
203 ));
204 }
205
206 let y_bin = label_binarize(y, &classes);
208
209 let (x_centered, y_centered, x_mean, y_mean) = if self.config.fit_intercept {
211 let x_mean = safe_mean_axis(x, Axis(0))?;
212 let y_mean = safe_mean_axis(&y_bin, Axis(0))?;
213 let x_centered = x - &x_mean;
214 let y_centered = if n_classes == 2 {
215 y_bin - y_mean[0]
217 } else {
218 &y_bin - &y_mean
220 };
221 (x_centered, y_centered, Some(x_mean), Some(y_mean))
222 } else {
223 (x.clone(), y_bin.clone(), None, None)
224 };
225
226 let mut coef = Array2::zeros((n_classes, n_features));
228
229 let xt_x = x_centered.t().dot(&x_centered);
231 let xt_x_reg =
232 &xt_x + &(Array2::<Float>::eye(n_features) * self.config.alpha * n_samples as Float);
233
234 if n_classes == 2 {
235 let xt_y = x_centered.t().dot(&y_centered.column(0));
237
238 match xt_x_reg.solve(&xt_y) {
239 Ok(solution) => {
240 coef.row_mut(0).assign(&(-&solution));
241 coef.row_mut(1).assign(&solution);
242 }
243 Err(_) => {
244 return Err(SklearsError::InvalidInput(
245 "Failed to solve linear system".to_string(),
246 ));
247 }
248 }
249 } else {
250 for k in 0..n_classes {
252 let xt_y = x_centered.t().dot(&y_centered.column(k));
253
254 match xt_x_reg.solve(&xt_y) {
255 Ok(solution) => {
256 coef.row_mut(k).assign(&solution);
257 }
258 Err(_) => {
259 return Err(SklearsError::InvalidInput(format!(
260 "Failed to solve linear system for class {}",
261 k
262 )));
263 }
264 }
265 }
266 }
267
268 let intercept = if self.config.fit_intercept {
270 let x_mean = x_mean.expect("x_mean should be Some when fit_intercept is true");
271 let y_mean = y_mean.expect("y_mean should be Some when fit_intercept is true");
272
273 if n_classes == 2 {
274 let intercept_val = y_mean[0] - x_mean.dot(&coef.row(1));
276 Array1::from_vec(vec![-intercept_val, intercept_val])
277 } else {
278 let mut intercept = Array1::zeros(n_classes);
280 for k in 0..n_classes {
281 intercept[k] = y_mean[k] - x_mean.dot(&coef.row(k));
282 }
283 intercept
284 }
285 } else {
286 Array1::zeros(n_classes)
287 };
288
289 Ok(RidgeClassifier {
290 config: self.config,
291 state: PhantomData,
292 coef_: Some(coef),
293 intercept_: Some(intercept),
294 classes_: Some(Array1::from_vec(classes)),
295 n_features_in_: Some(n_features),
296 })
297 }
298}
299
300impl Predict<Array2<Float>, Array1<Int>> for RidgeClassifier<Trained> {
301 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Int>> {
302 let coef = self
303 .coef_
304 .as_ref()
305 .expect("coef_ must be Some in Trained state");
306 let intercept = self
307 .intercept_
308 .as_ref()
309 .expect("intercept_ must be Some in Trained state");
310 let classes = self
311 .classes_
312 .as_ref()
313 .expect("classes_ must be Some in Trained state");
314
315 let scores = x.dot(&coef.t()) + intercept;
317
318 let mut predictions = Vec::with_capacity(scores.nrows());
320 for row in scores.axis_iter(Axis(0)) {
321 let max_idx = row
322 .iter()
323 .enumerate()
324 .max_by(|(_, a), (_, b)| compare_floats(a, b).unwrap_or(std::cmp::Ordering::Equal))
325 .map(|(idx, _)| idx)
326 .ok_or_else(|| SklearsError::InvalidInput("Empty row in scores".to_string()))?;
327 predictions.push(classes[max_idx]);
328 }
329
330 Ok(Array1::from_vec(predictions))
331 }
332}
333
334impl Score<Array2<Float>, Array1<Int>> for RidgeClassifier<Trained> {
335 type Float = Float;
336
337 fn score(&self, x: &Array2<Float>, y: &Array1<Int>) -> Result<Float> {
338 let predictions = self.predict(x)?;
339 let correct = predictions
340 .iter()
341 .zip(y.iter())
342 .filter(|(pred, true_val)| pred == true_val)
343 .count();
344
345 Ok(correct as Float / y.len() as Float)
346 }
347}
348
impl RidgeClassifier<Trained> {
    /// Learned coefficient matrix, one row per class.
    ///
    /// # Panics
    /// Never in practice: `coef_` is always `Some` in the `Trained` state.
    pub fn coef(&self) -> &Array2<Float> {
        self.coef_
            .as_ref()
            .expect("coef_ must be Some in Trained state")
    }

    /// Learned intercepts, one per class.
    /// NOTE(review): returns `Option` unlike the other trained accessors,
    /// even though `fit` always sets it — consider unifying the API.
    pub fn intercept(&self) -> Option<&Array1<Float>> {
        self.intercept_.as_ref()
    }

    /// Sorted unique class labels seen during fitting.
    ///
    /// # Panics
    /// Never in practice: `classes_` is always `Some` in the `Trained` state.
    pub fn classes(&self) -> &Array1<Int> {
        self.classes_
            .as_ref()
            .expect("classes_ must be Some in Trained state")
    }

    /// Number of features the model was fitted with.
    ///
    /// # Panics
    /// Never in practice: set by `fit` before reaching the `Trained` state.
    pub fn n_features_in(&self) -> usize {
        self.n_features_in_
            .expect("n_features_in_ must be Some in Trained state")
    }

    /// Raw decision scores `X · coef^T + intercept`, one column per class;
    /// `predict` takes the argmax of each row of this matrix.
    pub fn decision_function(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let coef = self
            .coef_
            .as_ref()
            .expect("coef_ must be Some in Trained state");
        let intercept = self
            .intercept_
            .as_ref()
            .expect("intercept_ must be Some in Trained state");

        Ok(x.dot(&coef.t()) + intercept)
    }
}
389
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    use scirs2_core::ndarray::array;

    // Linearly separable binary problem with symmetric clusters.
    #[test]
    fn test_ridge_classifier_binary() {
        let x = array![
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [-1.0, -1.0],
            [-2.0, -2.0],
            [-3.0, -3.0],
        ];
        let y = array![1, 1, 1, 0, 0, 0];

        let model = RidgeClassifier::new()
            .alpha(1.0)
            .fit(&x, &y)
            .expect("model fitting should succeed");

        let _predictions = model.predict(&x).expect("prediction should succeed");
        let accuracy = model.score(&x, &y).expect("scoring should succeed");

        // Separable data should be classified almost perfectly.
        assert!(accuracy > 0.8);

        // Two classes -> two coefficient rows (one per class).
        assert_eq!(model.classes().len(), 2);
        assert_eq!(model.coef().nrows(), 2);
    }

    // Three well-separated classes exercise the multiclass solve path.
    #[test]
    fn test_ridge_classifier_multiclass() {
        let x = array![
            [1.0, 1.0],
            [2.0, 2.0],
            [-1.0, -1.0],
            [-2.0, -2.0],
            [1.0, -1.0],
            [2.0, -2.0],
        ];
        let y = array![0, 0, 1, 1, 2, 2];

        let model = RidgeClassifier::new()
            .alpha(0.1)
            .fit(&x, &y)
            .expect("model fitting should succeed");

        let accuracy = model.score(&x, &y).expect("scoring should succeed");
        assert!(accuracy > 0.8);

        // One coefficient row per class.
        assert_eq!(model.classes().len(), 3);
        assert_eq!(model.coef().nrows(), 3);
    }

    // With fit_intercept(false) all intercepts must be exactly zero.
    #[test]
    fn test_ridge_classifier_no_intercept() {
        let x = array![[1.0, 1.0], [2.0, 2.0], [-1.0, -1.0], [-2.0, -2.0],];
        let y = array![1, 1, 0, 0];

        let model = RidgeClassifier::new()
            .fit_intercept(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        let intercept = model.intercept().expect("intercept should be available");
        assert!(intercept.iter().all(|&v| v == 0.0));
    }

    // Very large alpha should shrink all coefficients toward zero.
    #[test]
    fn test_ridge_classifier_strong_regularization() {
        let x = array![[1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [0.0, 2.0],];
        let y = array![0, 0, 1, 1];

        let model = RidgeClassifier::new()
            .alpha(1000.0)
            .fit(&x, &y)
            .expect("model fitting should succeed");

        let coef = model.coef();
        assert!(coef.iter().all(|&c| c.abs() < 0.1));
    }

    // decision_function and predict must agree: the predicted label is the
    // argmax of the corresponding score row.
    #[test]
    fn test_ridge_classifier_decision_function() {
        let x = array![[1.0, 1.0], [-1.0, -1.0],];
        let y = array![1, 0];

        let model = RidgeClassifier::new()
            .fit(&x, &y)
            .expect("model fitting should succeed");

        let decision = model
            .decision_function(&x)
            .expect("operation should succeed");

        // One score column per class.
        assert_eq!(decision.ncols(), 2);

        let predictions = model.predict(&x).expect("prediction should succeed");
        for (i, &pred) in predictions.iter().enumerate() {
            let scores = decision.row(i);
            let max_idx = scores
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
                .map(|(idx, _)| idx)
                .expect("operation should succeed");
            assert_eq!(model.classes()[max_idx], pred);
        }
    }

    // Binary encoding is a single {-1, +1} column; multiclass is one-vs-rest.
    #[test]
    fn test_label_binarize() {
        let y = array![0, 1, 1, 0];
        let classes = vec![0, 1];
        let y_bin = label_binarize(&y, &classes);

        assert_eq!(y_bin.shape(), &[4, 1]);
        assert_eq!(y_bin[[0, 0]], -1.0);
        assert_eq!(y_bin[[1, 0]], 1.0);

        let y = array![0, 1, 2, 0];
        let classes = vec![0, 1, 2];
        let y_bin = label_binarize(&y, &classes);

        assert_eq!(y_bin.shape(), &[4, 3]);
        assert_eq!(y_bin[[0, 0]], 1.0);
        assert_eq!(y_bin[[0, 1]], -1.0);
        assert_eq!(y_bin[[2, 2]], 1.0);
    }
}