sklears_linear/
ridge_classifier.rs

1//! Ridge Classifier
2//!
3//! Classifier using Ridge regression. This classifier first converts binary targets to
4//! {-1, 1} and then treats the problem as a regression task (multi-output regression in
5//! the multiclass case).
6
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_linalg::compat::ArrayLinalgExt;
9// Removed SVD import - using ArrayLinalgExt for both solve and svd methods
10use std::marker::PhantomData;
11
12use sklears_core::{
13    error::{Result, SklearsError},
14    traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
15    types::{Float, Int},
16};
17
18use crate::solver::Solver;
19
/// Configuration for RidgeClassifier
///
/// Hyper-parameters for the ridge-regression-based classifier. Construct via
/// `RidgeClassifierConfig::default()` or through the `RidgeClassifier`
/// builder methods.
#[derive(Debug, Clone)]
pub struct RidgeClassifierConfig {
    /// Regularization strength; must be a positive float.
    /// Larger values shrink the coefficients more strongly.
    pub alpha: Float,
    /// Whether to fit the intercept (inputs are mean-centered before solving)
    pub fit_intercept: bool,
    /// If True, the regressors X will be normalized before regression
    pub normalize: bool,
    /// Solver to use in the computational routines
    pub solver: Solver,
    /// Maximum number of iterations for iterative solvers
    pub max_iter: Option<usize>,
    /// Precision of the solution
    pub tol: Float,
    /// Random state for shuffling the data
    pub random_state: Option<u64>,
}
38
39impl Default for RidgeClassifierConfig {
40    fn default() -> Self {
41        Self {
42            alpha: 1.0,
43            fit_intercept: true,
44            normalize: false,
45            solver: Solver::Auto,
46            max_iter: None,
47            tol: 1e-3,
48            random_state: None,
49        }
50    }
51}
52
/// Ridge Classifier
///
/// Classifier that converts labels to {-1, 1} regression targets, solves a
/// ridge regression, and predicts by taking the argmax of per-class scores.
///
/// The typestate parameter `State` is `Untrained` until `fit` succeeds;
/// afterwards the `Option` fields below are populated (`Some`).
pub struct RidgeClassifier<State = Untrained> {
    // Hyper-parameters; carried over unchanged from the untrained instance.
    config: RidgeClassifierConfig,
    // Zero-sized marker realizing the Untrained/Trained typestate.
    state: PhantomData<State>,
    // (n_classes, n_features) weight matrix; None until fitted.
    coef_: Option<Array2<Float>>,
    // Per-class intercepts (zeros when fit_intercept is false); None until fitted.
    intercept_: Option<Array1<Float>>,
    // Sorted unique class labels observed during fit; None until fitted.
    classes_: Option<Array1<Int>>,
    // Column count of the training matrix; None until fitted.
    n_features_in_: Option<usize>,
}
62
63impl RidgeClassifier<Untrained> {
64    /// Create a new RidgeClassifier with default configuration
65    pub fn new() -> Self {
66        Self {
67            config: RidgeClassifierConfig::default(),
68            state: PhantomData,
69            coef_: None,
70            intercept_: None,
71            classes_: None,
72            n_features_in_: None,
73        }
74    }
75
76    /// Set the regularization strength
77    pub fn alpha(mut self, alpha: Float) -> Self {
78        self.config.alpha = alpha;
79        self
80    }
81
82    /// Set whether to fit the intercept
83    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
84        self.config.fit_intercept = fit_intercept;
85        self
86    }
87
88    /// Set whether to normalize the features
89    pub fn normalize(mut self, normalize: bool) -> Self {
90        self.config.normalize = normalize;
91        self
92    }
93
94    /// Set the solver
95    pub fn solver(mut self, solver: Solver) -> Self {
96        self.config.solver = solver;
97        self
98    }
99
100    /// Set the tolerance
101    pub fn tol(mut self, tol: Float) -> Self {
102        self.config.tol = tol;
103        self
104    }
105}
106
107impl Default for RidgeClassifier<Untrained> {
108    fn default() -> Self {
109        Self::new()
110    }
111}
112
// Estimator plumbing for the untrained typestate: exposes the associated
// float/config/error types and read access to the configuration.
impl Estimator for RidgeClassifier<Untrained> {
    type Float = Float;
    type Config = RidgeClassifierConfig;
    type Error = SklearsError;

    /// Borrow the hyper-parameter configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
122
// Estimator plumbing for the trained typestate; identical associated types
// so fitted models expose the same configuration interface.
impl Estimator for RidgeClassifier<Trained> {
    type Float = Float;
    type Config = RidgeClassifierConfig;
    type Error = SklearsError;

    /// Borrow the hyper-parameter configuration used during fitting.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
132
133/// Convert class labels to regression targets
134fn label_binarize(y: &Array1<Int>, classes: &[Int]) -> Array2<Float> {
135    let n_samples = y.len();
136    let n_classes = classes.len();
137
138    if n_classes == 2 {
139        // Binary case: convert to {-1, 1}
140        let mut y_bin = Array1::zeros(n_samples);
141        for (i, &label) in y.iter().enumerate() {
142            if label == classes[1] {
143                y_bin[i] = 1.0;
144            } else {
145                y_bin[i] = -1.0;
146            }
147        }
148        y_bin.insert_axis(Axis(1))
149    } else {
150        // Multi-class case: one-vs-all encoding
151        let mut y_bin = Array2::from_elem((n_samples, n_classes), -1.0);
152        for (i, &label) in y.iter().enumerate() {
153            for (j, &class) in classes.iter().enumerate() {
154                if label == class {
155                    y_bin[[i, j]] = 1.0;
156                }
157            }
158        }
159        y_bin
160    }
161}
162
impl Fit<Array2<Float>, Array1<Int>> for RidgeClassifier<Untrained> {
    type Fitted = RidgeClassifier<Trained>;

    /// Fit the ridge classifier on `x` (n_samples × n_features) and integer
    /// labels `y` (length n_samples).
    ///
    /// Labels are binarized to ±1 targets and the problem is solved as a
    /// (multi-output) ridge regression via the normal equations:
    /// `(XᵀX + alpha · n_samples · I) w = Xᵀy`.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidInput` when the sample counts of `x`
    /// and `y` disagree, when fewer than two distinct classes are present,
    /// or when the regularized linear system cannot be solved.
    ///
    /// NOTE(review): only `alpha` and `fit_intercept` are consulted here;
    /// `normalize`, `solver`, `max_iter`, `tol`, and `random_state` are
    /// currently ignored by this fit path — confirm that is intended.
    fn fit(self, x: &Array2<Float>, y: &Array1<Int>) -> Result<Self::Fitted> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        if n_samples != y.len() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        // Get unique classes, sorted ascending so class order is deterministic
        let mut classes: Vec<Int> = y.iter().copied().collect();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        if n_classes < 2 {
            return Err(SklearsError::InvalidInput(
                "At least two classes are required".to_string(),
            ));
        }

        // Convert labels to ±1 regression targets
        let y_bin = label_binarize(y, &classes);

        // Center X and y if fitting intercept: centering lets us solve without
        // an explicit bias column and recover the intercept from the means.
        let (x_centered, y_centered, x_mean, y_mean) = if self.config.fit_intercept {
            let x_mean = x.mean_axis(Axis(0)).unwrap();
            let y_mean = y_bin.mean_axis(Axis(0)).unwrap();
            let x_centered = x - &x_mean;
            let y_centered = if n_classes == 2 {
                // For binary case, just center the single column
                y_bin - y_mean[0]
            } else {
                // For multi-class, center each column
                &y_bin - &y_mean
            };
            (x_centered, y_centered, Some(x_mean), Some(y_mean))
        } else {
            (x.clone(), y_bin.clone(), None, None)
        };

        // Solve the ridge regression problem for each class
        let mut coef = Array2::zeros((n_classes, n_features));

        // Compute XᵀX + alpha · n_samples · I.
        // NOTE(review): the penalty is scaled by n_samples; scikit-learn's
        // Ridge does not scale alpha this way — confirm the convention.
        let xt_x = x_centered.t().dot(&x_centered);
        let xt_x_reg =
            &xt_x + &(Array2::<Float>::eye(n_features) * self.config.alpha * n_samples as Float);

        if n_classes == 2 {
            // Binary case: solve once, then store ±w so that the argmax over
            // the two coefficient rows reproduces the sign decision.
            let xt_y = x_centered.t().dot(&y_centered.column(0));

            match xt_x_reg.solve(&xt_y) {
                Ok(solution) => {
                    coef.row_mut(0).assign(&(-&solution));
                    coef.row_mut(1).assign(&solution);
                }
                Err(_) => {
                    return Err(SklearsError::InvalidInput(
                        "Failed to solve linear system".to_string(),
                    ));
                }
            }
        } else {
            // Multi-class case: solve one-vs-rest system per class column
            for k in 0..n_classes {
                let xt_y = x_centered.t().dot(&y_centered.column(k));

                match xt_x_reg.solve(&xt_y) {
                    Ok(solution) => {
                        coef.row_mut(k).assign(&solution);
                    }
                    Err(_) => {
                        return Err(SklearsError::InvalidInput(format!(
                            "Failed to solve linear system for class {}",
                            k
                        )));
                    }
                }
            }
        }

        // Compute intercept if needed: b_k = mean(y_k) - mean(x) · w_k
        let intercept = if self.config.fit_intercept {
            let x_mean = x_mean.unwrap();
            let y_mean = y_mean.unwrap();

            if n_classes == 2 {
                // Binary case: intercepts mirror the ±w coefficient rows
                let intercept_val = y_mean[0] - x_mean.dot(&coef.row(1));
                Array1::from_vec(vec![-intercept_val, intercept_val])
            } else {
                // Multi-class case
                let mut intercept = Array1::zeros(n_classes);
                for k in 0..n_classes {
                    intercept[k] = y_mean[k] - x_mean.dot(&coef.row(k));
                }
                intercept
            }
        } else {
            // No intercept requested: store zeros so downstream code can add
            // the intercept vector unconditionally.
            Array1::zeros(n_classes)
        };

        Ok(RidgeClassifier {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef),
            intercept_: Some(intercept),
            classes_: Some(Array1::from_vec(classes)),
            n_features_in_: Some(n_features),
        })
    }
}
281
282impl Predict<Array2<Float>, Array1<Int>> for RidgeClassifier<Trained> {
283    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Int>> {
284        let coef = self.coef_.as_ref().unwrap();
285        let intercept = self.intercept_.as_ref().unwrap();
286        let classes = self.classes_.as_ref().unwrap();
287
288        // Compute decision function
289        let scores = x.dot(&coef.t()) + intercept;
290
291        // Predict class with maximum score
292        let predictions = scores
293            .axis_iter(Axis(0))
294            .map(|row| {
295                let max_idx = row
296                    .iter()
297                    .enumerate()
298                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
299                    .map(|(idx, _)| idx)
300                    .unwrap();
301                classes[max_idx]
302            })
303            .collect();
304
305        Ok(Array1::from_vec(predictions))
306    }
307}
308
309impl Score<Array2<Float>, Array1<Int>> for RidgeClassifier<Trained> {
310    type Float = Float;
311
312    fn score(&self, x: &Array2<Float>, y: &Array1<Int>) -> Result<Float> {
313        let predictions = self.predict(x)?;
314        let correct = predictions
315            .iter()
316            .zip(y.iter())
317            .filter(|(pred, true_val)| pred == true_val)
318            .count();
319
320        Ok(correct as Float / y.len() as Float)
321    }
322}
323
324impl RidgeClassifier<Trained> {
325    /// Get the coefficients
326    pub fn coef(&self) -> &Array2<Float> {
327        self.coef_.as_ref().unwrap()
328    }
329
330    /// Get the intercept
331    pub fn intercept(&self) -> Option<&Array1<Float>> {
332        self.intercept_.as_ref()
333    }
334
335    /// Get the classes
336    pub fn classes(&self) -> &Array1<Int> {
337        self.classes_.as_ref().unwrap()
338    }
339
340    /// Get the number of features seen during fit
341    pub fn n_features_in(&self) -> usize {
342        self.n_features_in_.unwrap()
343    }
344
345    /// Get decision function (raw scores)
346    pub fn decision_function(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
347        let coef = self.coef_.as_ref().unwrap();
348        let intercept = self.intercept_.as_ref().unwrap();
349
350        Ok(x.dot(&coef.t()) + intercept)
351    }
352}
353
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    use scirs2_core::ndarray::array;

    #[test]
    fn test_ridge_classifier_binary() {
        // Two linearly separable point clouds along the main diagonal.
        let features = array![
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [-1.0, -1.0],
            [-2.0, -2.0],
            [-3.0, -3.0],
        ];
        let labels = array![1, 1, 1, 0, 0, 0];

        let clf = RidgeClassifier::new()
            .alpha(1.0)
            .fit(&features, &labels)
            .unwrap();

        let _predictions = clf.predict(&features).unwrap();
        let train_accuracy = clf.score(&features, &labels).unwrap();

        // Easy data: training accuracy should be high.
        assert!(train_accuracy > 0.8);

        // Binary problems still expose a two-row coefficient matrix.
        assert_eq!(clf.classes().len(), 2);
        assert_eq!(clf.coef().nrows(), 2);
    }

    #[test]
    fn test_ridge_classifier_multiclass() {
        let features = array![
            [1.0, 1.0],
            [2.0, 2.0],
            [-1.0, -1.0],
            [-2.0, -2.0],
            [1.0, -1.0],
            [2.0, -2.0],
        ];
        let labels = array![0, 0, 1, 1, 2, 2];

        let clf = RidgeClassifier::new()
            .alpha(0.1)
            .fit(&features, &labels)
            .unwrap();

        let train_accuracy = clf.score(&features, &labels).unwrap();
        assert!(train_accuracy > 0.8);

        // One coefficient row per class in the one-vs-rest formulation.
        assert_eq!(clf.classes().len(), 3);
        assert_eq!(clf.coef().nrows(), 3);
    }

    #[test]
    fn test_ridge_classifier_no_intercept() {
        let features = array![[1.0, 1.0], [2.0, 2.0], [-1.0, -1.0], [-2.0, -2.0],];
        let labels = array![1, 1, 0, 0];

        let clf = RidgeClassifier::new()
            .fit_intercept(false)
            .fit(&features, &labels)
            .unwrap();

        // With intercept fitting disabled, the stored intercepts are all zero.
        let intercept = clf.intercept().unwrap();
        assert!(intercept.iter().all(|&value| value == 0.0));
    }

    #[test]
    fn test_ridge_classifier_strong_regularization() {
        let features = array![[1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [0.0, 2.0],];
        let labels = array![0, 0, 1, 1];

        // A very large alpha should shrink every coefficient toward zero.
        let clf = RidgeClassifier::new()
            .alpha(1000.0)
            .fit(&features, &labels)
            .unwrap();

        assert!(clf.coef().iter().all(|&weight| weight.abs() < 0.1));
    }

    #[test]
    fn test_ridge_classifier_decision_function() {
        let features = array![[1.0, 1.0], [-1.0, -1.0],];
        let labels = array![1, 0];

        let clf = RidgeClassifier::new().fit(&features, &labels).unwrap();

        let scores = clf.decision_function(&features).unwrap();

        // Binary classification still yields one score column per class.
        assert_eq!(scores.ncols(), 2);

        // predict must agree with the argmax of the raw scores.
        let predictions = clf.predict(&features).unwrap();
        for (i, &predicted) in predictions.iter().enumerate() {
            let row = scores.row(i);
            let best = row
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap();
            assert_eq!(clf.classes()[best], predicted);
        }
    }

    #[test]
    fn test_label_binarize() {
        // Binary case: one column of ±1, +1 marking the second class.
        let labels = array![0, 1, 1, 0];
        let class_list = vec![0, 1];
        let encoded = label_binarize(&labels, &class_list);

        assert_eq!(encoded.shape(), &[4, 1]);
        assert_eq!(encoded[[0, 0]], -1.0);
        assert_eq!(encoded[[1, 0]], 1.0);

        // Multi-class case: one-vs-rest matrix, +1 on the true class.
        let labels = array![0, 1, 2, 0];
        let class_list = vec![0, 1, 2];
        let encoded = label_binarize(&labels, &class_list);

        assert_eq!(encoded.shape(), &[4, 3]);
        assert_eq!(encoded[[0, 0]], 1.0);
        assert_eq!(encoded[[0, 1]], -1.0);
        assert_eq!(encoded[[2, 2]], 1.0);
    }
}