smartcore/linear/
logistic_regression.rs

//! # Logistic Regression
//!
//! As with [Linear Regression](../linear_regression/index.html), logistic regression models your outcome as a linear combination of the predictor variables \\(X\\), but rather than modeling this response directly,
//! logistic regression models the probability that \\(y\\) belongs to a particular category, \\(Pr(y = 1|X) \\), as:
//!
//! \\[ Pr(y=1) \approx \frac{e^{\beta_0 + \sum_{i=1}^n \beta_iX_i}}{1 + e^{\beta_0 + \sum_{i=1}^n \beta_iX_i}} \\]
//!
//! `smartcore` uses the [limited-memory BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) method to find estimates of the regression coefficients, \\(\beta\\).
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::logistic_regression::*;
//!
//! // Iris data
//! let x = DenseMatrix::from_2d_array(&[
//!           &[5.1, 3.5, 1.4, 0.2],
//!           &[4.9, 3.0, 1.4, 0.2],
//!           &[4.7, 3.2, 1.3, 0.2],
//!           &[4.6, 3.1, 1.5, 0.2],
//!           &[5.0, 3.6, 1.4, 0.2],
//!           &[5.4, 3.9, 1.7, 0.4],
//!           &[4.6, 3.4, 1.4, 0.3],
//!           &[5.0, 3.4, 1.5, 0.2],
//!           &[4.4, 2.9, 1.4, 0.2],
//!           &[4.9, 3.1, 1.5, 0.1],
//!           &[7.0, 3.2, 4.7, 1.4],
//!           &[6.4, 3.2, 4.5, 1.5],
//!           &[6.9, 3.1, 4.9, 1.5],
//!           &[5.5, 2.3, 4.0, 1.3],
//!           &[6.5, 2.8, 4.6, 1.5],
//!           &[5.7, 2.8, 4.5, 1.3],
//!           &[6.3, 3.3, 4.7, 1.6],
//!           &[4.9, 2.4, 3.3, 1.0],
//!           &[6.6, 2.9, 4.6, 1.3],
//!           &[5.2, 2.7, 3.9, 1.4],
//!           ]).unwrap();
//! let y: Vec<i32> = vec![
//!           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
//! ];
//!
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = lr.predict(&x).unwrap();
//! ```
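//!
//! L2 regularization is controlled by the `alpha` parameter of
//! [`LogisticRegressionParameters`]; the default `alpha = 0` disables it. A minimal
//! sketch with made-up toy data (larger `alpha` shrinks the coefficients towards zero):
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::logistic_regression::*;
//!
//! // Hypothetical toy data, linearly separable in the first feature
//! let x = DenseMatrix::from_2d_array(&[
//!           &[1., -5.], &[2., 5.], &[3., -2.], &[1., 2.],
//!           &[6., -5.], &[7., 5.], &[6., -2.], &[7., 2.],
//!           ]).unwrap();
//! let y: Vec<i32> = vec![0, 0, 0, 0, 1, 1, 1, 1];
//!
//! let lr = LogisticRegression::fit(
//!     &x, &y,
//!     LogisticRegressionParameters::default().with_alpha(1.0),
//! ).unwrap();
//! let y_hat = lr.predict(&x).unwrap();
//! ```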
//!
//! ## References:
//! * ["Pattern Recognition and Machine Learning", C.M. Bishop, Linear Models for Classification](https://www.microsoft.com/en-us/research/uploads/prod/2006/01/Bishop-Pattern-Recognition-and-Machine-Learning-2006.pdf)
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 4.3 Logistic Regression](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["On the Limited Memory Method for Large Scale Optimization", Nocedal et al., Mathematical Programming, 1989](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited.pdf)
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::cmp::Ordering;
use std::fmt::Debug;
use std::marker::PhantomData;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
use crate::optimization::first_order::lbfgs::LBFGS;
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq, Default)]
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
pub enum LogisticRegressionSolverName {
    /// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
    #[default]
    LBFGS,
}

/// Logistic Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LogisticRegressionParameters<T: Number + FloatNumber> {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Solver to use for estimation of regression coefficients.
    pub solver: LogisticRegressionSolverName,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Regularization parameter.
    pub alpha: T,
}
/// Logistic Regression grid search parameters
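///
/// Each combination of `solver` and `alpha` yields one [`LogisticRegressionParameters`]
/// instance when iterated. A minimal sketch:
///
/// ```
/// use smartcore::linear::logistic_regression::LogisticRegressionSearchParameters;
///
/// let grid = LogisticRegressionSearchParameters {
///     alpha: vec![0.0, 1.0],
///     ..Default::default()
/// };
/// for params in grid {
///     // one candidate parameter set per (solver, alpha) pair
///     assert!(params.alpha == 0.0 || params.alpha == 1.0);
/// }
/// ```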
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LogisticRegressionSearchParameters<T: Number> {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Solver to use for estimation of regression coefficients.
    pub solver: Vec<LogisticRegressionSolverName>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Regularization parameter.
    pub alpha: Vec<T>,
}

/// Logistic Regression grid search iterator
pub struct LogisticRegressionSearchParametersIterator<T: Number> {
    logistic_regression_search_parameters: LogisticRegressionSearchParameters<T>,
    current_solver: usize,
    current_alpha: usize,
}

impl<T: Number + FloatNumber> IntoIterator for LogisticRegressionSearchParameters<T> {
    type Item = LogisticRegressionParameters<T>;
    type IntoIter = LogisticRegressionSearchParametersIterator<T>;

    fn into_iter(self) -> Self::IntoIter {
        LogisticRegressionSearchParametersIterator {
            logistic_regression_search_parameters: self,
            current_solver: 0,
            current_alpha: 0,
        }
    }
}

impl<T: Number + FloatNumber> Iterator for LogisticRegressionSearchParametersIterator<T> {
    type Item = LogisticRegressionParameters<T>;

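    // Walks the Cartesian product of `solver` and `alpha`: all alpha values are
    // visited for the current solver before advancing to the next solver.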
    fn next(&mut self) -> Option<Self::Item> {
        if self.current_alpha == self.logistic_regression_search_parameters.alpha.len()
            && self.current_solver == self.logistic_regression_search_parameters.solver.len()
        {
            return None;
        }

        let next = LogisticRegressionParameters {
            solver: self.logistic_regression_search_parameters.solver[self.current_solver].clone(),
            alpha: self.logistic_regression_search_parameters.alpha[self.current_alpha],
        };

        if self.current_alpha + 1 < self.logistic_regression_search_parameters.alpha.len() {
            self.current_alpha += 1;
        } else if self.current_solver + 1 < self.logistic_regression_search_parameters.solver.len()
        {
            self.current_alpha = 0;
            self.current_solver += 1;
        } else {
            self.current_alpha += 1;
            self.current_solver += 1;
        }

        Some(next)
    }
}

impl<T: Number + FloatNumber> Default for LogisticRegressionSearchParameters<T> {
    fn default() -> Self {
        let default_params = LogisticRegressionParameters::default();

        LogisticRegressionSearchParameters {
            solver: vec![default_params.solver],
            alpha: vec![default_params.alpha],
        }
    }
}

/// Logistic Regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LogisticRegression<
    TX: Number + FloatNumber + RealNumber,
    TY: Number + Ord,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    coefficients: Option<X>,
    intercept: Option<X>,
    classes: Option<Vec<TY>>,
    num_attributes: usize,
    num_classes: usize,
    _phantom_tx: PhantomData<TX>,
    _phantom_y: PhantomData<Y>,
}

trait ObjectiveFunction<T: Number + FloatNumber, X: Array2<T>> {
    fn f(&self, w_bias: &[T]) -> T;

    #[allow(clippy::ptr_arg)]
    fn df(&self, g: &mut Vec<T>, w_bias: &Vec<T>);

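    /// Dot product of row `m_row` of `x` with the weight block starting at offset
    /// `v_col` in the flattened vector `w` (`p` weights followed by one bias term).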
    #[allow(clippy::ptr_arg)]
    fn partial_dot(w: &[T], x: &X, v_col: usize, m_row: usize) -> T {
        let mut sum = T::zero();
        let p = x.shape().1;
        for i in 0..p {
            sum += *x.get((m_row, i)) * w[i + v_col];
        }

        sum + w[p + v_col]
    }
}

struct BinaryObjectiveFunction<'a, T: Number + FloatNumber, X: Array2<T>> {
    x: &'a X,
    y: Vec<usize>,
    alpha: T,
    _phantom_t: PhantomData<T>,
}

impl<T: Number + FloatNumber> LogisticRegressionParameters<T> {
    /// Solver to use for estimation of regression coefficients.
    pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self {
        self.solver = solver;
        self
    }
    /// Regularization parameter.
    pub fn with_alpha(mut self, alpha: T) -> Self {
        self.alpha = alpha;
        self
    }
}

impl<T: Number + FloatNumber> Default for LogisticRegressionParameters<T> {
    fn default() -> Self {
        LogisticRegressionParameters {
            solver: LogisticRegressionSolverName::default(),
            alpha: T::zero(),
        }
    }
}

impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    PartialEq for LogisticRegression<TX, TY, X, Y>
{
    fn eq(&self, other: &Self) -> bool {
        if self.num_classes != other.num_classes
            || self.num_attributes != other.num_attributes
            || self.classes().len() != other.classes().len()
        {
            false
        } else {
            for i in 0..self.classes().len() {
                if self.classes()[i] != other.classes()[i] {
                    return false;
                }
            }

            self.coefficients()
                .iterator(0)
                .zip(other.coefficients().iterator(0))
                .all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
                && self
                    .intercept()
                    .iterator(0)
                    .zip(other.intercept().iterator(0))
                    .all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
        }
    }
}

impl<T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
    for BinaryObjectiveFunction<'_, T, X>
{
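    // Negative log-likelihood of the binary logistic model, plus an optional L2 penalty.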
    fn f(&self, w_bias: &[T]) -> T {
        let mut f = T::zero();
        let (n, p) = self.x.shape();

        for i in 0..n {
            let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
            f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx;
        }

        if self.alpha > T::zero() {
            let mut w_squared = T::zero();
            for w_bias_i in w_bias.iter().take(p) {
                w_squared += *w_bias_i * *w_bias_i;
            }
            f += T::from_f64(0.5).unwrap() * self.alpha * w_squared;
        }

        f
    }

    fn df(&self, g: &mut Vec<T>, w_bias: &Vec<T>) {
        g.copy_from(&Vec::zeros(g.len()));

        let (n, p) = self.x.shape();

        for i in 0..n {
            let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);

            let dyi = (T::from(self.y[i]).unwrap()) - wx.sigmoid();
            for (j, g_j) in g.iter_mut().enumerate().take(p) {
                *g_j -= dyi * *self.x.get((i, j));
            }
            g[p] -= dyi;
        }

        if self.alpha > T::zero() {
            for i in 0..p {
                let w = w_bias[i];
                g[i] += self.alpha * w;
            }
        }
    }
}

struct MultiClassObjectiveFunction<'a, T: Number + FloatNumber, X: Array2<T>> {
    x: &'a X,
    y: Vec<usize>,
    k: usize,
    alpha: T,
    _phantom_t: PhantomData<T>,
}

impl<T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
    for MultiClassObjectiveFunction<'_, T, X>
{
    fn f(&self, w_bias: &[T]) -> T {
        let mut f = T::zero();
        let mut prob = vec![T::zero(); self.k];
        let (n, p) = self.x.shape();
        for i in 0..n {
            for (j, prob_j) in prob.iter_mut().enumerate().take(self.k) {
                *prob_j = MultiClassObjectiveFunction::partial_dot(w_bias, self.x, j * (p + 1), i);
            }
            prob.softmax_mut();
            f -= prob[self.y[i]].ln();
        }

        if self.alpha > T::zero() {
            let mut w_squared = T::zero();
            for i in 0..self.k {
                for j in 0..p {
                    let wi = w_bias[i * (p + 1) + j];
                    w_squared += wi * wi;
                }
            }
            f += T::from_f64(0.5).unwrap() * self.alpha * w_squared;
        }

        f
    }

    fn df(&self, g: &mut Vec<T>, w: &Vec<T>) {
        g.copy_from(&Vec::zeros(g.len()));

        let mut prob = vec![T::zero(); self.k];
        let (n, p) = self.x.shape();

        for i in 0..n {
            for (j, prob_j) in prob.iter_mut().enumerate().take(self.k) {
                *prob_j = MultiClassObjectiveFunction::partial_dot(w, self.x, j * (p + 1), i);
            }

            prob.softmax_mut();

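            // Cross-entropy gradient: for each class j, sample i contributes
            // (prob[j] - [y_i == j]) * x_i to the class-j weight block.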
            for j in 0..self.k {
                let yi = (if self.y[i] == j { T::one() } else { T::zero() }) - prob[j];

                for l in 0..p {
                    let pos = j * (p + 1);
                    g[pos + l] -= yi * *self.x.get((i, l));
                }
                g[j * (p + 1) + p] -= yi;
            }
        }

        if self.alpha > T::zero() {
            for i in 0..self.k {
                for j in 0..p {
                    let pos = i * (p + 1);
                    let wi = w[pos + j];
                    g[pos + j] += self.alpha * wi;
                }
            }
        }
    }
}

impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, LogisticRegressionParameters<TX>>
    for LogisticRegression<TX, TY, X, Y>
{
    fn new() -> Self {
        Self {
            coefficients: Option::None,
            intercept: Option::None,
            classes: Option::None,
            num_attributes: 0,
            num_classes: 0,
            _phantom_tx: PhantomData,
            _phantom_y: PhantomData,
        }
    }

    fn fit(x: &X, y: &Y, parameters: LogisticRegressionParameters<TX>) -> Result<Self, Failed> {
        LogisticRegression::fit(x, y, parameters)
    }
}

impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    Predictor<X, Y> for LogisticRegression<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    LogisticRegression<TX, TY, X, Y>
{
    /// Fits Logistic Regression to your data.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - target class values
    /// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: LogisticRegressionParameters<TX>,
    ) -> Result<LogisticRegression<TX, TY, X, Y>, Failed> {
        let (x_nrows, num_attributes) = x.shape();
        let y_nrows = y.shape();

        if x_nrows != y_nrows {
            return Err(Failed::fit(
                "Number of rows of X doesn't match number of rows of Y",
            ));
        }

        let classes = y.unique();

        let k = classes.len();

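        // Map each label in `y` to its index in `classes`.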
        let mut yi: Vec<usize> = vec![0; y_nrows];

        for (i, yi_i) in yi.iter_mut().enumerate().take(y_nrows) {
            let yc = y.get(i);
            *yi_i = classes.iter().position(|c| yc == c).unwrap();
        }

        match k.cmp(&2) {
            Ordering::Less => Err(Failed::fit(&format!(
                "incorrect number of classes: {k}. Should be >= 2."
            ))),
            Ordering::Equal => {
                let x0 = Vec::zeros(num_attributes + 1);

                let objective = BinaryObjectiveFunction {
                    x,
                    y: yi,
                    alpha: parameters.alpha,
                    _phantom_t: PhantomData,
                };

                let result = Self::minimize(x0, objective);

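                // The flat solution vector is laid out as [w_1, ..., w_p, bias].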
                let weights = X::from_iterator(result.x.into_iter(), 1, num_attributes + 1, 0);
                let coefficients = weights.slice(0..1, 0..num_attributes);
                let intercept = weights.slice(0..1, num_attributes..num_attributes + 1);

                Ok(LogisticRegression {
                    coefficients: Some(X::from_slice(coefficients.as_ref())),
                    intercept: Some(X::from_slice(intercept.as_ref())),
                    classes: Some(classes),
                    num_attributes,
                    num_classes: k,
                    _phantom_tx: PhantomData,
                    _phantom_y: PhantomData,
                })
            }
            Ordering::Greater => {
                let x0 = Vec::zeros((num_attributes + 1) * k);

                let objective = MultiClassObjectiveFunction {
                    x,
                    y: yi,
                    k,
                    alpha: parameters.alpha,
                    _phantom_t: PhantomData,
                };

                let result = Self::minimize(x0, objective);
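                // Flat layout: k blocks of (p weights + 1 bias), one block per class.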
                let weights = X::from_iterator(result.x.into_iter(), k, num_attributes + 1, 0);
                let coefficients = weights.slice(0..k, 0..num_attributes);
                let intercept = weights.slice(0..k, num_attributes..num_attributes + 1);

                Ok(LogisticRegression {
                    coefficients: Some(X::from_slice(coefficients.as_ref())),
                    intercept: Some(X::from_slice(intercept.as_ref())),
                    classes: Some(classes),
                    num_attributes,
                    num_classes: k,
                    _phantom_tx: PhantomData,
                    _phantom_y: PhantomData,
                })
            }
        }
    }

    /// Predict class labels for samples in `x`.
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let n = x.shape().0;
        let mut result = Y::zeros(n);
        if self.num_classes == 2 {
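            // Binary: predict the class at index 1 when sigmoid(w*x + b) > 0.5.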
            let y_hat = x.ab(false, self.coefficients(), true);
            let intercept = *self.intercept().get((0, 0));
            for (i, y_hat_i) in y_hat.iterator(0).enumerate().take(n) {
                result.set(
                    i,
                    self.classes()[usize::from(
                        RealNumber::sigmoid(*y_hat_i + intercept) > RealNumber::half(),
                    )],
                );
            }
        } else {
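            // Multiclass: add intercepts, then assign each row the class with the
            // largest linear score.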
            let mut y_hat = x.matmul(&self.coefficients().transpose());
            for r in 0..n {
                for c in 0..self.num_classes {
                    y_hat.set((r, c), *y_hat.get((r, c)) + *self.intercept().get((c, 0)));
                }
            }
            let class_idxs = y_hat.argmax(1);
            for (i, class_i) in class_idxs.iter().enumerate().take(n) {
                result.set(i, self.classes()[*class_i]);
            }
        }
        Ok(result)
    }

    /// Get estimates of the regression coefficients; this creates a shared reference
    pub fn coefficients(&self) -> &X {
        self.coefficients.as_ref().unwrap()
    }

    /// Get the estimate of the intercept; this creates a shared reference
    pub fn intercept(&self) -> &X {
        self.intercept.as_ref().unwrap()
    }

    /// Get the class labels; this creates a shared reference
    pub fn classes(&self) -> &Vec<TY> {
        self.classes.as_ref().unwrap()
    }

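    /// Minimizes the objective with L-BFGS, using a backtracking line search.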
    fn minimize(
        x0: Vec<TX>,
        objective: impl ObjectiveFunction<TX, X>,
    ) -> OptimizerResult<TX, Vec<TX>> {
        let f = |w: &Vec<TX>| -> TX { objective.f(w) };

        let df = |g: &mut Vec<TX>, w: &Vec<TX>| objective.df(g, w);

        let ls: Backtracking<TX> = Backtracking {
            order: FunctionOrder::THIRD,
            ..Default::default()
        };
        let optimizer: LBFGS = Default::default();

        optimizer.optimize(&f, &df, &x0, &ls)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "datasets")]
    use crate::dataset::generator::make_blobs;
    use crate::linalg::basic::arrays::Array;
    use crate::linalg::basic::matrix::DenseMatrix;

    #[test]
    fn search_parameters() {
        let parameters = LogisticRegressionSearchParameters {
            alpha: vec![0., 1.],
            ..Default::default()
        };
        let mut iter = parameters.into_iter();
        assert_eq!(iter.next().unwrap().alpha, 0.);
        assert_eq!(
            iter.next().unwrap().solver,
            LogisticRegressionSolverName::LBFGS
        );
        assert!(iter.next().is_none());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn multiclass_objective_f() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., -5.],
            &[2., 5.],
            &[3., -2.],
            &[1., 2.],
            &[2., 0.],
            &[6., -5.],
            &[7., 5.],
            &[6., -2.],
            &[7., 2.],
            &[6., 0.],
            &[8., -5.],
            &[9., 5.],
            &[10., -2.],
            &[8., 2.],
            &[9., 0.],
        ])
        .unwrap();

        let y = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];

        let objective = MultiClassObjectiveFunction {
            x: &x,
            y: y.clone(),
            k: 3,
            alpha: 0.0,
            _phantom_t: PhantomData,
        };

        let mut g = vec![0f64; 9];

        objective.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
        objective.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);

        assert!((g[0] + 33.000068218163484).abs() < f64::EPSILON);

        let f = objective.f(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]);

        assert!((f - 408.0052230582765).abs() < f64::EPSILON);

        let objective_reg = MultiClassObjectiveFunction {
            x: &x,
            y,
            k: 3,
            alpha: 1.0,
            _phantom_t: PhantomData,
        };

        let f = objective_reg.f(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
        assert!((f - 487.5052).abs() < 1e-4);

        objective_reg.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
        assert!((g[0].abs() - 32.0).abs() < 1e-4);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn binary_objective_f() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., -5.],
            &[2., 5.],
            &[3., -2.],
            &[1., 2.],
            &[2., 0.],
            &[6., -5.],
            &[7., 5.],
            &[6., -2.],
            &[7., 2.],
            &[6., 0.],
            &[8., -5.],
            &[9., 5.],
            &[10., -2.],
            &[8., 2.],
            &[9., 0.],
        ])
        .unwrap();

        let y = vec![0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1];

        let objective = BinaryObjectiveFunction {
            x: &x,
            y: y.clone(),
            alpha: 0.0,
            _phantom_t: PhantomData,
        };

        let mut g = vec![0f64; 3];

        objective.df(&mut g, &vec![1., 2., 3.]);
        objective.df(&mut g, &vec![1., 2., 3.]);

        assert!((g[0] - 26.051064349381285).abs() < f64::EPSILON);
        assert!((g[1] - 10.239000702928523).abs() < f64::EPSILON);
        assert!((g[2] - 3.869294270156324).abs() < f64::EPSILON);

        let f = objective.f(&[1., 2., 3.]);

        assert!((f - 59.76994756647412).abs() < f64::EPSILON);

        let objective_reg = BinaryObjectiveFunction {
            x: &x,
            y,
            alpha: 1.0,
            _phantom_t: PhantomData,
        };

        let f = objective_reg.f(&[1., 2., 3.]);
        assert!((f - 62.2699).abs() < 1e-4);

        objective_reg.df(&mut g, &vec![1., 2., 3.]);
        assert!((g[0] - 27.0511).abs() < 1e-4);
        assert!((g[1] - 12.239).abs() < 1e-4);
        assert!((g[2] - 3.8693).abs() < 1e-4);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn lr_fit_predict() {
        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
            &[1., -5.],
            &[2., 5.],
            &[3., -2.],
            &[1., 2.],
            &[2., 0.],
            &[6., -5.],
            &[7., 5.],
            &[6., -2.],
            &[7., 2.],
            &[6., 0.],
            &[8., -5.],
            &[9., 5.],
            &[10., -2.],
            &[8., 2.],
            &[9., 0.],
        ])
        .unwrap();
        let y: Vec<i32> = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];

        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();

        assert_eq!(lr.coefficients().shape(), (3, 2));
        assert_eq!(lr.intercept().shape(), (3, 1));

        assert!((*lr.coefficients().get((0, 0)) - 0.0435).abs() < 1e-4);
        assert!(
            (*lr.intercept().get((0, 0)) - 0.1250).abs() < 1e-4,
            "expected to be less than 1e-4, got {}",
            (*lr.intercept().get((0, 0)) - 0.1250).abs()
        );

        let y_hat = lr.predict(&x).unwrap();

        assert_eq!(y_hat, vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);
    }

    #[cfg(feature = "datasets")]
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn lr_fit_predict_multiclass() {
        let blobs = make_blobs(15, 4, 3);

        let x: DenseMatrix<f32> = DenseMatrix::from_iterator(blobs.data.into_iter(), 15, 4, 0);
        let y: Vec<i32> = blobs.target.into_iter().map(|v| v as i32).collect();

        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();

        let y_hat = lr.predict(&x).unwrap();

        assert_eq!(y_hat, vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]);

        let lr_reg = LogisticRegression::fit(
            &x,
            &y,
            LogisticRegressionParameters::default().with_alpha(10.0),
        )
        .unwrap();

        let reg_coeff_sum: f32 = lr_reg.coefficients().abs().iter().sum();
        let coeff: f32 = lr.coefficients().abs().iter().sum();

        assert!(reg_coeff_sum < coeff);
    }

    #[cfg(feature = "datasets")]
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn lr_fit_predict_binary() {
        let blobs = make_blobs(20, 4, 2);

        let x = DenseMatrix::from_iterator(blobs.data.into_iter(), 20, 4, 0);
        let y: Vec<i32> = blobs.target.into_iter().map(|v| v as i32).collect();

        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();

        let y_hat = lr.predict(&x).unwrap();

        assert_eq!(
            y_hat,
            vec![0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
        );

        let lr_reg = LogisticRegression::fit(
            &x,
            &y,
            LogisticRegressionParameters::default().with_alpha(10.0),
        )
        .unwrap();

        let reg_coeff_sum: f32 = lr_reg.coefficients().abs().iter().sum();
        let coeff: f32 = lr.coefficients().abs().iter().sum();

        assert!(reg_coeff_sum < coeff);
    }

    // TODO: serialization for the new DenseMatrix needs to be implemented
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    #[cfg(feature = "serde")]
    fn serde() {
        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
            &[1., -5.],
            &[2., 5.],
            &[3., -2.],
            &[1., 2.],
            &[2., 0.],
            &[6., -5.],
            &[7., 5.],
            &[6., -2.],
            &[7., 2.],
            &[6., 0.],
            &[8., -5.],
            &[9., 5.],
            &[10., -2.],
            &[8., 2.],
            &[9., 0.],
        ])
        .unwrap();
        let y: Vec<i32> = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];

        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();

        let deserialized_lr: LogisticRegression<f64, i32, DenseMatrix<f64>, Vec<i32>> =
            serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();

        assert_eq!(lr, deserialized_lr);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn lr_fit_predict_iris() {
        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
        ])
        .unwrap();
        let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];

        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
        let lr_reg = LogisticRegression::fit(
            &x,
            &y,
            LogisticRegressionParameters::default().with_alpha(1.0),
        )
        .unwrap();

        let y_hat = lr.predict(&x).unwrap();

        let error: i32 = y.into_iter().zip(y_hat).map(|(a, b)| (a - b).abs()).sum();

        assert!(error <= 1);

        let reg_coeff_sum: f32 = lr_reg.coefficients().abs().iter().sum();
        let coeff: f32 = lr.coefficients().abs().iter().sum();

        assert!(reg_coeff_sum < coeff);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn lr_fit_predict_random() {
        let x: DenseMatrix<f32> = DenseMatrix::rand(52181, 94);
        let y1: Vec<i32> = vec![1; 2181];
        let y2: Vec<i32> = vec![0; 50000];
        let y: Vec<i32> = y1.into_iter().chain(y2).collect();

        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
        let lr_reg = LogisticRegression::fit(
            &x,
            &y,
            LogisticRegressionParameters::default().with_alpha(1.0),
        )
        .unwrap();

        let y_hat = lr.predict(&x).unwrap();
        let y_hat_reg = lr_reg.predict(&x).unwrap();

        assert_eq!(y.len(), y_hat.len());
        assert_eq!(y.len(), y_hat_reg.len());
    }

    #[test]
    fn test_logit() {
        let x: &DenseMatrix<f64> = &DenseMatrix::rand(52181, 94);
        let y1: Vec<u32> = vec![1; 2181];
        let y2: Vec<u32> = vec![0; 50000];
        let y: &Vec<u32> = &(y1.into_iter().chain(y2).collect());
        println!("y vec height: {:?}", y.len());
        println!("x matrix shape: {:?}", x.shape());

        let lr = LogisticRegression::fit(x, y, Default::default()).unwrap();
        let y_hat = lr.predict(x).unwrap();

        println!("y_hat shape: {:?}", y_hat.shape());

        assert_eq!(y_hat.shape(), 52181);
    }
}