// sklears_linear/ridge_classifier.rs
//! Ridge Classifier
//!
//! Classifier using Ridge regression. This classifier first converts binary targets to
//! {-1, 1} and then treats the problem as a regression task (multi-output regression in
//! the multiclass case).

7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_linalg::compat::ArrayLinalgExt;
9// Removed SVD import - using ArrayLinalgExt for both solve and svd methods
10use std::marker::PhantomData;
11
12use sklears_core::{
13    error::{Result, SklearsError},
14    traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
15    types::{Float, Int},
16};
17
18use crate::solver::Solver;
19
20// Helper function for safe mean computation along axis
21fn safe_mean_axis(arr: &Array2<Float>, axis: Axis) -> Result<Array1<Float>> {
22    if arr.is_empty() {
23        return Err(SklearsError::InvalidInput(
24            "Cannot compute mean of empty array".to_string(),
25        ));
26    }
27    arr.mean_axis(axis).ok_or_else(|| {
28        SklearsError::InvalidInput("Mean computation failed (empty axis)".to_string())
29    })
30}
31
32// Helper function for NaN-safe float comparison
33fn compare_floats(a: &Float, b: &Float) -> Result<std::cmp::Ordering> {
34    a.partial_cmp(b)
35        .ok_or_else(|| SklearsError::InvalidInput("NaN encountered in comparison".to_string()))
36}
37
/// Configuration for RidgeClassifier
///
/// NOTE(review): the visible `fit` implementation consumes only `alpha` and
/// `fit_intercept`; `normalize`, `solver`, `max_iter`, `tol`, and
/// `random_state` are stored but not read by it — confirm whether they are
/// wired up elsewhere or reserved for future solvers.
#[derive(Debug, Clone)]
pub struct RidgeClassifierConfig {
    /// Regularization strength; must be a positive float
    pub alpha: Float,
    /// Whether to fit the intercept
    pub fit_intercept: bool,
    /// If True, the regressors X will be normalized before regression
    pub normalize: bool,
    /// Solver to use in the computational routines
    pub solver: Solver,
    /// Maximum number of iterations for iterative solvers
    pub max_iter: Option<usize>,
    /// Precision of the solution
    pub tol: Float,
    /// Random state for shuffling the data
    pub random_state: Option<u64>,
}
56
57impl Default for RidgeClassifierConfig {
58    fn default() -> Self {
59        Self {
60            alpha: 1.0,
61            fit_intercept: true,
62            normalize: false,
63            solver: Solver::Auto,
64            max_iter: None,
65            tol: 1e-3,
66            random_state: None,
67        }
68    }
69}
70
/// Ridge Classifier
///
/// Typestate-parameterized estimator: `State` is `Untrained` until `fit`
/// succeeds, which returns a `RidgeClassifier<Trained>` whose fitted fields
/// are guaranteed to be populated.
pub struct RidgeClassifier<State = Untrained> {
    /// Hyperparameters used for fitting.
    config: RidgeClassifierConfig,
    /// Zero-sized marker carrying the Untrained/Trained typestate.
    state: PhantomData<State>,
    /// Coefficient matrix of shape (n_classes, n_features); `Some` only after fit.
    coef_: Option<Array2<Float>>,
    /// Per-class intercepts; `Some` only after fit.
    intercept_: Option<Array1<Float>>,
    /// Sorted unique class labels seen during fit.
    classes_: Option<Array1<Int>>,
    /// Number of features seen during fit.
    n_features_in_: Option<usize>,
}
80
81impl RidgeClassifier<Untrained> {
82    /// Create a new RidgeClassifier with default configuration
83    pub fn new() -> Self {
84        Self {
85            config: RidgeClassifierConfig::default(),
86            state: PhantomData,
87            coef_: None,
88            intercept_: None,
89            classes_: None,
90            n_features_in_: None,
91        }
92    }
93
94    /// Set the regularization strength
95    pub fn alpha(mut self, alpha: Float) -> Self {
96        self.config.alpha = alpha;
97        self
98    }
99
100    /// Set whether to fit the intercept
101    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
102        self.config.fit_intercept = fit_intercept;
103        self
104    }
105
106    /// Set whether to normalize the features
107    pub fn normalize(mut self, normalize: bool) -> Self {
108        self.config.normalize = normalize;
109        self
110    }
111
112    /// Set the solver
113    pub fn solver(mut self, solver: Solver) -> Self {
114        self.config.solver = solver;
115        self
116    }
117
118    /// Set the tolerance
119    pub fn tol(mut self, tol: Float) -> Self {
120        self.config.tol = tol;
121        self
122    }
123}
124
125impl Default for RidgeClassifier<Untrained> {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
impl Estimator for RidgeClassifier<Untrained> {
    type Float = Float;
    type Config = RidgeClassifierConfig;
    type Error = SklearsError;

    /// Expose the builder-configured hyperparameters.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
140
impl Estimator for RidgeClassifier<Trained> {
    type Float = Float;
    type Config = RidgeClassifierConfig;
    type Error = SklearsError;

    /// Expose the hyperparameters the model was fitted with.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
150
151/// Convert class labels to regression targets
152fn label_binarize(y: &Array1<Int>, classes: &[Int]) -> Array2<Float> {
153    let n_samples = y.len();
154    let n_classes = classes.len();
155
156    if n_classes == 2 {
157        // Binary case: convert to {-1, 1}
158        let mut y_bin = Array1::zeros(n_samples);
159        for (i, &label) in y.iter().enumerate() {
160            if label == classes[1] {
161                y_bin[i] = 1.0;
162            } else {
163                y_bin[i] = -1.0;
164            }
165        }
166        y_bin.insert_axis(Axis(1))
167    } else {
168        // Multi-class case: one-vs-all encoding
169        let mut y_bin = Array2::from_elem((n_samples, n_classes), -1.0);
170        for (i, &label) in y.iter().enumerate() {
171            for (j, &class) in classes.iter().enumerate() {
172                if label == class {
173                    y_bin[[i, j]] = 1.0;
174                }
175            }
176        }
177        y_bin
178    }
179}
180
impl Fit<Array2<Float>, Array1<Int>> for RidgeClassifier<Untrained> {
    type Fitted = RidgeClassifier<Trained>;

    /// Fit the classifier via a closed-form ridge solve.
    ///
    /// Labels are binarized to {-1, 1} targets (one column per class in the
    /// multiclass case), features/targets are optionally centered, and the
    /// normal equations `(XᵀX + alpha·n·I) w = Xᵀy` are solved per column.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidInput` when `x`/`y` lengths disagree,
    /// fewer than two classes are present, or the linear solve fails.
    fn fit(self, x: &Array2<Float>, y: &Array1<Int>) -> Result<Self::Fitted> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        if n_samples != y.len() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        // Collect the sorted, deduplicated set of class labels; their order
        // defines the row order of `coef` and column order of the targets.
        let mut classes: Vec<Int> = y.iter().copied().collect();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        if n_classes < 2 {
            return Err(SklearsError::InvalidInput(
                "At least two classes are required".to_string(),
            ));
        }

        // Convert labels to {-1, 1} regression targets.
        let y_bin = label_binarize(y, &classes);

        // Center X and y if fitting an intercept, so the solve can ignore it;
        // the intercept is recovered from the means afterwards.
        let (x_centered, y_centered, x_mean, y_mean) = if self.config.fit_intercept {
            let x_mean = safe_mean_axis(x, Axis(0))?;
            let y_mean = safe_mean_axis(&y_bin, Axis(0))?;
            let x_centered = x - &x_mean;
            let y_centered = if n_classes == 2 {
                // For binary case, just center the single column
                y_bin - y_mean[0]
            } else {
                // For multi-class, center each column (row-wise broadcast)
                &y_bin - &y_mean
            };
            (x_centered, y_centered, Some(x_mean), Some(y_mean))
        } else {
            (x.clone(), y_bin.clone(), None, None)
        };

        // Solve the ridge regression problem for each class
        let mut coef = Array2::zeros((n_classes, n_features));

        // Regularized Gram matrix: XᵀX + alpha·n·I.
        // NOTE(review): the penalty is scaled by `n_samples`, unlike
        // scikit-learn's Ridge which applies `alpha` unscaled — confirm this
        // scaling is intended. Also note `self.config.solver`/`tol`/`max_iter`
        // are not consulted here; only the direct solve is used.
        let xt_x = x_centered.t().dot(&x_centered);
        let xt_x_reg =
            &xt_x + &(Array2::<Float>::eye(n_features) * self.config.alpha * n_samples as Float);

        if n_classes == 2 {
            // Binary case: one solve; class 0 gets the negated weights so the
            // later argmax over rows picks the correct side.
            let xt_y = x_centered.t().dot(&y_centered.column(0));

            match xt_x_reg.solve(&xt_y) {
                Ok(solution) => {
                    coef.row_mut(0).assign(&(-&solution));
                    coef.row_mut(1).assign(&solution);
                }
                Err(_) => {
                    return Err(SklearsError::InvalidInput(
                        "Failed to solve linear system".to_string(),
                    ));
                }
            }
        } else {
            // Multi-class case: independent solve per one-vs-all column.
            for k in 0..n_classes {
                let xt_y = x_centered.t().dot(&y_centered.column(k));

                match xt_x_reg.solve(&xt_y) {
                    Ok(solution) => {
                        coef.row_mut(k).assign(&solution);
                    }
                    Err(_) => {
                        return Err(SklearsError::InvalidInput(format!(
                            "Failed to solve linear system for class {}",
                            k
                        )));
                    }
                }
            }
        }

        // Recover the intercept from the means: b_k = ȳ_k − x̄·w_k.
        let intercept = if self.config.fit_intercept {
            let x_mean = x_mean.expect("x_mean should be Some when fit_intercept is true");
            let y_mean = y_mean.expect("y_mean should be Some when fit_intercept is true");

            if n_classes == 2 {
                // Binary case: intercepts mirror the ±w coefficient rows.
                let intercept_val = y_mean[0] - x_mean.dot(&coef.row(1));
                Array1::from_vec(vec![-intercept_val, intercept_val])
            } else {
                // Multi-class case
                let mut intercept = Array1::zeros(n_classes);
                for k in 0..n_classes {
                    intercept[k] = y_mean[k] - x_mean.dot(&coef.row(k));
                }
                intercept
            }
        } else {
            Array1::zeros(n_classes)
        };

        Ok(RidgeClassifier {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef),
            intercept_: Some(intercept),
            classes_: Some(Array1::from_vec(classes)),
            n_features_in_: Some(n_features),
        })
    }
}
299
300impl Predict<Array2<Float>, Array1<Int>> for RidgeClassifier<Trained> {
301    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Int>> {
302        let coef = self
303            .coef_
304            .as_ref()
305            .expect("coef_ must be Some in Trained state");
306        let intercept = self
307            .intercept_
308            .as_ref()
309            .expect("intercept_ must be Some in Trained state");
310        let classes = self
311            .classes_
312            .as_ref()
313            .expect("classes_ must be Some in Trained state");
314
315        // Compute decision function
316        let scores = x.dot(&coef.t()) + intercept;
317
318        // Predict class with maximum score
319        let mut predictions = Vec::with_capacity(scores.nrows());
320        for row in scores.axis_iter(Axis(0)) {
321            let max_idx = row
322                .iter()
323                .enumerate()
324                .max_by(|(_, a), (_, b)| compare_floats(a, b).unwrap_or(std::cmp::Ordering::Equal))
325                .map(|(idx, _)| idx)
326                .ok_or_else(|| SklearsError::InvalidInput("Empty row in scores".to_string()))?;
327            predictions.push(classes[max_idx]);
328        }
329
330        Ok(Array1::from_vec(predictions))
331    }
332}
333
334impl Score<Array2<Float>, Array1<Int>> for RidgeClassifier<Trained> {
335    type Float = Float;
336
337    fn score(&self, x: &Array2<Float>, y: &Array1<Int>) -> Result<Float> {
338        let predictions = self.predict(x)?;
339        let correct = predictions
340            .iter()
341            .zip(y.iter())
342            .filter(|(pred, true_val)| pred == true_val)
343            .count();
344
345        Ok(correct as Float / y.len() as Float)
346    }
347}
348
349impl RidgeClassifier<Trained> {
350    /// Get the coefficients
351    pub fn coef(&self) -> &Array2<Float> {
352        self.coef_
353            .as_ref()
354            .expect("coef_ must be Some in Trained state")
355    }
356
357    /// Get the intercept
358    pub fn intercept(&self) -> Option<&Array1<Float>> {
359        self.intercept_.as_ref()
360    }
361
362    /// Get the classes
363    pub fn classes(&self) -> &Array1<Int> {
364        self.classes_
365            .as_ref()
366            .expect("classes_ must be Some in Trained state")
367    }
368
369    /// Get the number of features seen during fit
370    pub fn n_features_in(&self) -> usize {
371        self.n_features_in_
372            .expect("n_features_in_ must be Some in Trained state")
373    }
374
375    /// Get decision function (raw scores)
376    pub fn decision_function(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
377        let coef = self
378            .coef_
379            .as_ref()
380            .expect("coef_ must be Some in Trained state");
381        let intercept = self
382            .intercept_
383            .as_ref()
384            .expect("intercept_ must be Some in Trained state");
385
386        Ok(x.dot(&coef.t()) + intercept)
387    }
388}
389
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    use scirs2_core::ndarray::array;

    /// Linearly separable binary data should be classified accurately, and the
    /// binary model should still expose one coefficient row per class.
    #[test]
    fn test_ridge_classifier_binary() {
        let features = array![
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [-1.0, -1.0],
            [-2.0, -2.0],
            [-3.0, -3.0],
        ];
        let labels = array![1, 1, 1, 0, 0, 0];

        let clf = RidgeClassifier::new()
            .alpha(1.0)
            .fit(&features, &labels)
            .expect("model fitting should succeed");

        let _ = clf.predict(&features).expect("prediction should succeed");
        let accuracy = clf.score(&features, &labels).expect("scoring should succeed");
        assert!(accuracy > 0.8);

        assert_eq!(clf.classes().len(), 2);
        assert_eq!(clf.coef().nrows(), 2);
    }

    /// Three well-separated clusters should be recovered with high accuracy.
    #[test]
    fn test_ridge_classifier_multiclass() {
        let features = array![
            [1.0, 1.0],
            [2.0, 2.0],
            [-1.0, -1.0],
            [-2.0, -2.0],
            [1.0, -1.0],
            [2.0, -2.0],
        ];
        let labels = array![0, 0, 1, 1, 2, 2];

        let clf = RidgeClassifier::new()
            .alpha(0.1)
            .fit(&features, &labels)
            .expect("model fitting should succeed");

        let accuracy = clf.score(&features, &labels).expect("scoring should succeed");
        assert!(accuracy > 0.8);

        assert_eq!(clf.classes().len(), 3);
        assert_eq!(clf.coef().nrows(), 3);
    }

    /// Disabling the intercept must leave all intercept entries at zero.
    #[test]
    fn test_ridge_classifier_no_intercept() {
        let features = array![[1.0, 1.0], [2.0, 2.0], [-1.0, -1.0], [-2.0, -2.0],];
        let labels = array![1, 1, 0, 0];

        let clf = RidgeClassifier::new()
            .fit_intercept(false)
            .fit(&features, &labels)
            .expect("operation should succeed");

        let intercept = clf.intercept().expect("intercept should be available");
        assert!(intercept.iter().all(|&v| v == 0.0));
    }

    /// A very large alpha should shrink every coefficient toward zero.
    #[test]
    fn test_ridge_classifier_strong_regularization() {
        let features = array![[1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [0.0, 2.0],];
        let labels = array![0, 0, 1, 1];

        let clf = RidgeClassifier::new()
            .alpha(1000.0)
            .fit(&features, &labels)
            .expect("model fitting should succeed");

        assert!(clf.coef().iter().all(|&c| c.abs() < 0.1));
    }

    /// The predicted class must always correspond to the largest raw score.
    #[test]
    fn test_ridge_classifier_decision_function() {
        let features = array![[1.0, 1.0], [-1.0, -1.0],];
        let labels = array![1, 0];

        let clf = RidgeClassifier::new()
            .fit(&features, &labels)
            .expect("model fitting should succeed");

        let decision = clf
            .decision_function(&features)
            .expect("operation should succeed");

        // Binary classification exposes two score columns.
        assert_eq!(decision.ncols(), 2);

        let predictions = clf.predict(&features).expect("prediction should succeed");
        for (i, &predicted) in predictions.iter().enumerate() {
            let row = decision.row(i);
            let argmax = row
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
                .map(|(idx, _)| idx)
                .expect("operation should succeed");
            assert_eq!(clf.classes()[argmax], predicted);
        }
    }

    /// Binary encoding yields a single {-1, 1} column; multiclass encoding
    /// yields a one-vs-all matrix.
    #[test]
    fn test_label_binarize() {
        // Binary case
        let labels = array![0, 1, 1, 0];
        let class_set = vec![0, 1];
        let encoded = label_binarize(&labels, &class_set);

        assert_eq!(encoded.shape(), &[4, 1]);
        assert_eq!(encoded[[0, 0]], -1.0);
        assert_eq!(encoded[[1, 0]], 1.0);

        // Multi-class case
        let labels = array![0, 1, 2, 0];
        let class_set = vec![0, 1, 2];
        let encoded = label_binarize(&labels, &class_set);

        assert_eq!(encoded.shape(), &[4, 3]);
        assert_eq!(encoded[[0, 0]], 1.0);
        assert_eq!(encoded[[0, 1]], -1.0);
        assert_eq!(encoded[[2, 2]], 1.0);
    }
}