1use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_linalg::compat::ArrayLinalgExt;
9use std::marker::PhantomData;
11
12use sklears_core::{
13 error::{Result, SklearsError},
14 traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
15 types::{Float, Int},
16};
17
18use crate::solver::Solver;
19
20#[derive(Debug, Clone)]
22pub struct RidgeClassifierConfig {
23 pub alpha: Float,
25 pub fit_intercept: bool,
27 pub normalize: bool,
29 pub solver: Solver,
31 pub max_iter: Option<usize>,
33 pub tol: Float,
35 pub random_state: Option<u64>,
37}
38
39impl Default for RidgeClassifierConfig {
40 fn default() -> Self {
41 Self {
42 alpha: 1.0,
43 fit_intercept: true,
44 normalize: false,
45 solver: Solver::Auto,
46 max_iter: None,
47 tol: 1e-3,
48 random_state: None,
49 }
50 }
51}
52
53pub struct RidgeClassifier<State = Untrained> {
55 config: RidgeClassifierConfig,
56 state: PhantomData<State>,
57 coef_: Option<Array2<Float>>,
58 intercept_: Option<Array1<Float>>,
59 classes_: Option<Array1<Int>>,
60 n_features_in_: Option<usize>,
61}
62
63impl RidgeClassifier<Untrained> {
64 pub fn new() -> Self {
66 Self {
67 config: RidgeClassifierConfig::default(),
68 state: PhantomData,
69 coef_: None,
70 intercept_: None,
71 classes_: None,
72 n_features_in_: None,
73 }
74 }
75
76 pub fn alpha(mut self, alpha: Float) -> Self {
78 self.config.alpha = alpha;
79 self
80 }
81
82 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
84 self.config.fit_intercept = fit_intercept;
85 self
86 }
87
88 pub fn normalize(mut self, normalize: bool) -> Self {
90 self.config.normalize = normalize;
91 self
92 }
93
94 pub fn solver(mut self, solver: Solver) -> Self {
96 self.config.solver = solver;
97 self
98 }
99
100 pub fn tol(mut self, tol: Float) -> Self {
102 self.config.tol = tol;
103 self
104 }
105}
106
107impl Default for RidgeClassifier<Untrained> {
108 fn default() -> Self {
109 Self::new()
110 }
111}
112
113impl Estimator for RidgeClassifier<Untrained> {
114 type Float = Float;
115 type Config = RidgeClassifierConfig;
116 type Error = SklearsError;
117
118 fn config(&self) -> &Self::Config {
119 &self.config
120 }
121}
122
123impl Estimator for RidgeClassifier<Trained> {
124 type Float = Float;
125 type Config = RidgeClassifierConfig;
126 type Error = SklearsError;
127
128 fn config(&self) -> &Self::Config {
129 &self.config
130 }
131}
132
133fn label_binarize(y: &Array1<Int>, classes: &[Int]) -> Array2<Float> {
135 let n_samples = y.len();
136 let n_classes = classes.len();
137
138 if n_classes == 2 {
139 let mut y_bin = Array1::zeros(n_samples);
141 for (i, &label) in y.iter().enumerate() {
142 if label == classes[1] {
143 y_bin[i] = 1.0;
144 } else {
145 y_bin[i] = -1.0;
146 }
147 }
148 y_bin.insert_axis(Axis(1))
149 } else {
150 let mut y_bin = Array2::from_elem((n_samples, n_classes), -1.0);
152 for (i, &label) in y.iter().enumerate() {
153 for (j, &class) in classes.iter().enumerate() {
154 if label == class {
155 y_bin[[i, j]] = 1.0;
156 }
157 }
158 }
159 y_bin
160 }
161}
162
163impl Fit<Array2<Float>, Array1<Int>> for RidgeClassifier<Untrained> {
164 type Fitted = RidgeClassifier<Trained>;
165
166 fn fit(self, x: &Array2<Float>, y: &Array1<Int>) -> Result<Self::Fitted> {
167 let n_samples = x.nrows();
168 let n_features = x.ncols();
169
170 if n_samples != y.len() {
171 return Err(SklearsError::InvalidInput(
172 "X and y must have the same number of samples".to_string(),
173 ));
174 }
175
176 let mut classes: Vec<Int> = y.iter().copied().collect();
178 classes.sort_unstable();
179 classes.dedup();
180 let n_classes = classes.len();
181
182 if n_classes < 2 {
183 return Err(SklearsError::InvalidInput(
184 "At least two classes are required".to_string(),
185 ));
186 }
187
188 let y_bin = label_binarize(y, &classes);
190
191 let (x_centered, y_centered, x_mean, y_mean) = if self.config.fit_intercept {
193 let x_mean = x.mean_axis(Axis(0)).unwrap();
194 let y_mean = y_bin.mean_axis(Axis(0)).unwrap();
195 let x_centered = x - &x_mean;
196 let y_centered = if n_classes == 2 {
197 y_bin - y_mean[0]
199 } else {
200 &y_bin - &y_mean
202 };
203 (x_centered, y_centered, Some(x_mean), Some(y_mean))
204 } else {
205 (x.clone(), y_bin.clone(), None, None)
206 };
207
208 let mut coef = Array2::zeros((n_classes, n_features));
210
211 let xt_x = x_centered.t().dot(&x_centered);
213 let xt_x_reg =
214 &xt_x + &(Array2::<Float>::eye(n_features) * self.config.alpha * n_samples as Float);
215
216 if n_classes == 2 {
217 let xt_y = x_centered.t().dot(&y_centered.column(0));
219
220 match xt_x_reg.solve(&xt_y) {
221 Ok(solution) => {
222 coef.row_mut(0).assign(&(-&solution));
223 coef.row_mut(1).assign(&solution);
224 }
225 Err(_) => {
226 return Err(SklearsError::InvalidInput(
227 "Failed to solve linear system".to_string(),
228 ));
229 }
230 }
231 } else {
232 for k in 0..n_classes {
234 let xt_y = x_centered.t().dot(&y_centered.column(k));
235
236 match xt_x_reg.solve(&xt_y) {
237 Ok(solution) => {
238 coef.row_mut(k).assign(&solution);
239 }
240 Err(_) => {
241 return Err(SklearsError::InvalidInput(format!(
242 "Failed to solve linear system for class {}",
243 k
244 )));
245 }
246 }
247 }
248 }
249
250 let intercept = if self.config.fit_intercept {
252 let x_mean = x_mean.unwrap();
253 let y_mean = y_mean.unwrap();
254
255 if n_classes == 2 {
256 let intercept_val = y_mean[0] - x_mean.dot(&coef.row(1));
258 Array1::from_vec(vec![-intercept_val, intercept_val])
259 } else {
260 let mut intercept = Array1::zeros(n_classes);
262 for k in 0..n_classes {
263 intercept[k] = y_mean[k] - x_mean.dot(&coef.row(k));
264 }
265 intercept
266 }
267 } else {
268 Array1::zeros(n_classes)
269 };
270
271 Ok(RidgeClassifier {
272 config: self.config,
273 state: PhantomData,
274 coef_: Some(coef),
275 intercept_: Some(intercept),
276 classes_: Some(Array1::from_vec(classes)),
277 n_features_in_: Some(n_features),
278 })
279 }
280}
281
282impl Predict<Array2<Float>, Array1<Int>> for RidgeClassifier<Trained> {
283 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Int>> {
284 let coef = self.coef_.as_ref().unwrap();
285 let intercept = self.intercept_.as_ref().unwrap();
286 let classes = self.classes_.as_ref().unwrap();
287
288 let scores = x.dot(&coef.t()) + intercept;
290
291 let predictions = scores
293 .axis_iter(Axis(0))
294 .map(|row| {
295 let max_idx = row
296 .iter()
297 .enumerate()
298 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
299 .map(|(idx, _)| idx)
300 .unwrap();
301 classes[max_idx]
302 })
303 .collect();
304
305 Ok(Array1::from_vec(predictions))
306 }
307}
308
309impl Score<Array2<Float>, Array1<Int>> for RidgeClassifier<Trained> {
310 type Float = Float;
311
312 fn score(&self, x: &Array2<Float>, y: &Array1<Int>) -> Result<Float> {
313 let predictions = self.predict(x)?;
314 let correct = predictions
315 .iter()
316 .zip(y.iter())
317 .filter(|(pred, true_val)| pred == true_val)
318 .count();
319
320 Ok(correct as Float / y.len() as Float)
321 }
322}
323
324impl RidgeClassifier<Trained> {
325 pub fn coef(&self) -> &Array2<Float> {
327 self.coef_.as_ref().unwrap()
328 }
329
330 pub fn intercept(&self) -> Option<&Array1<Float>> {
332 self.intercept_.as_ref()
333 }
334
335 pub fn classes(&self) -> &Array1<Int> {
337 self.classes_.as_ref().unwrap()
338 }
339
340 pub fn n_features_in(&self) -> usize {
342 self.n_features_in_.unwrap()
343 }
344
345 pub fn decision_function(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
347 let coef = self.coef_.as_ref().unwrap();
348 let intercept = self.intercept_.as_ref().unwrap();
349
350 Ok(x.dot(&coef.t()) + intercept)
351 }
352}
353
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    use scirs2_core::ndarray::array;

    /// Two clusters on opposite sides of the origin: checks accuracy and the
    /// shapes of the learned attributes.
    #[test]
    fn test_ridge_classifier_binary() {
        let features = array![
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [-1.0, -1.0],
            [-2.0, -2.0],
            [-3.0, -3.0],
        ];
        let labels = array![1, 1, 1, 0, 0, 0];

        let fitted = RidgeClassifier::new()
            .alpha(1.0)
            .fit(&features, &labels)
            .unwrap();

        // Prediction on the training data must succeed.
        let _predicted = fitted.predict(&features).unwrap();
        let acc = fitted.score(&features, &labels).unwrap();

        assert!(acc > 0.8);

        assert_eq!(fitted.classes().len(), 2);
        assert_eq!(fitted.coef().nrows(), 2);
    }

    /// Three classes: checks accuracy and one coefficient row per class.
    #[test]
    fn test_ridge_classifier_multiclass() {
        let features = array![
            [1.0, 1.0],
            [2.0, 2.0],
            [-1.0, -1.0],
            [-2.0, -2.0],
            [1.0, -1.0],
            [2.0, -2.0],
        ];
        let labels = array![0, 0, 1, 1, 2, 2];

        let fitted = RidgeClassifier::new()
            .alpha(0.1)
            .fit(&features, &labels)
            .unwrap();

        let acc = fitted.score(&features, &labels).unwrap();
        assert!(acc > 0.8);

        assert_eq!(fitted.classes().len(), 3);
        assert_eq!(fitted.coef().nrows(), 3);
    }

    /// Disabling the intercept must yield an all-zero intercept vector.
    #[test]
    fn test_ridge_classifier_no_intercept() {
        let features = array![[1.0, 1.0], [2.0, 2.0], [-1.0, -1.0], [-2.0, -2.0]];
        let labels = array![1, 1, 0, 0];

        let fitted = RidgeClassifier::new()
            .fit_intercept(false)
            .fit(&features, &labels)
            .unwrap();

        let bias = fitted.intercept().unwrap();
        assert!(bias.iter().all(|&b| b == 0.0));
    }

    /// A very large penalty must shrink every coefficient towards zero.
    #[test]
    fn test_ridge_classifier_strong_regularization() {
        let features = array![[1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [0.0, 2.0]];
        let labels = array![0, 0, 1, 1];

        let fitted = RidgeClassifier::new()
            .alpha(1000.0)
            .fit(&features, &labels)
            .unwrap();

        let weights = fitted.coef();
        assert!(weights.iter().all(|&w| w.abs() < 0.1));
    }

    /// `predict` must agree with the arg-max of `decision_function`.
    #[test]
    fn test_ridge_classifier_decision_function() {
        let features = array![[1.0, 1.0], [-1.0, -1.0]];
        let labels = array![1, 0];

        let fitted = RidgeClassifier::new().fit(&features, &labels).unwrap();

        let decision = fitted.decision_function(&features).unwrap();

        assert_eq!(decision.ncols(), 2);

        let predicted = fitted.predict(&features).unwrap();
        for (scores, &label) in decision.axis_iter(Axis(0)).zip(predicted.iter()) {
            let argmax = scores
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap();
            assert_eq!(fitted.classes()[argmax], label);
        }
    }

    /// Checks the `±1` encoding for both the binary and the multiclass layout.
    #[test]
    fn test_label_binarize() {
        // Binary: a single signed column.
        let labels = array![0, 1, 1, 0];
        let encoded = label_binarize(&labels, &[0, 1]);

        assert_eq!(encoded.shape(), &[4, 1]);
        assert_eq!(encoded[[0, 0]], -1.0);
        assert_eq!(encoded[[1, 0]], 1.0);

        // Multiclass: one column per class.
        let labels = array![0, 1, 2, 0];
        let encoded = label_binarize(&labels, &[0, 1, 2]);

        assert_eq!(encoded.shape(), &[4, 3]);
        assert_eq!(encoded[[0, 0]], 1.0);
        assert_eq!(encoded[[0, 1]], -1.0);
        assert_eq!(encoded[[2, 2]], 1.0);
    }
}