use std::cmp::Ordering;
use std::fmt::Debug;
use std::marker::PhantomData;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
use crate::optimization::first_order::lbfgs::LBFGS;
use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder;

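/// Solver options. Only the limited-memory BFGS (L-BFGS) optimizer is
/// currently available.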
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub enum LogisticRegressionSolverName {
    #[default]
    LBFGS,
}

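/// Parameters of the logistic regression estimator. `alpha` sets the strength
/// of the L2 (ridge) penalty on the coefficients; `alpha = 0` (the default)
/// disables regularization.
///
/// A minimal sketch of building non-default parameters (the alpha value is
/// illustrative):
///
/// ```ignore
/// let params = LogisticRegressionParameters::default()
///     .with_solver(LogisticRegressionSolverName::LBFGS)
///     .with_alpha(0.5);
/// ```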
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LogisticRegressionParameters<T: Number + FloatNumber> {
    #[cfg_attr(feature = "serde", serde(default))]
    pub solver: LogisticRegressionSolverName,
    #[cfg_attr(feature = "serde", serde(default))]
    pub alpha: T,
}

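/// Grid of candidate parameter values for model selection. Iterating the grid
/// yields one `LogisticRegressionParameters` per solver/alpha combination.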
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LogisticRegressionSearchParameters<T: Number> {
    #[cfg_attr(feature = "serde", serde(default))]
    pub solver: Vec<LogisticRegressionSolverName>,
    #[cfg_attr(feature = "serde", serde(default))]
    pub alpha: Vec<T>,
}

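/// Iterator state for `LogisticRegressionSearchParameters`: walks the
/// cartesian product of the `solver` and `alpha` vectors, with `alpha`
/// varying fastest.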
pub struct LogisticRegressionSearchParametersIterator<T: Number> {
    logistic_regression_search_parameters: LogisticRegressionSearchParameters<T>,
    current_solver: usize,
    current_alpha: usize,
}

impl<T: Number + FloatNumber> IntoIterator for LogisticRegressionSearchParameters<T> {
    type Item = LogisticRegressionParameters<T>;
    type IntoIter = LogisticRegressionSearchParametersIterator<T>;

    fn into_iter(self) -> Self::IntoIter {
        LogisticRegressionSearchParametersIterator {
            logistic_regression_search_parameters: self,
            current_solver: 0,
            current_alpha: 0,
        }
    }
}

impl<T: Number + FloatNumber> Iterator for LogisticRegressionSearchParametersIterator<T> {
    type Item = LogisticRegressionParameters<T>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_alpha == self.logistic_regression_search_parameters.alpha.len()
            && self.current_solver == self.logistic_regression_search_parameters.solver.len()
        {
            return None;
        }

        let next = LogisticRegressionParameters {
            solver: self.logistic_regression_search_parameters.solver[self.current_solver].clone(),
            alpha: self.logistic_regression_search_parameters.alpha[self.current_alpha],
        };

        if self.current_alpha + 1 < self.logistic_regression_search_parameters.alpha.len() {
            self.current_alpha += 1;
        } else if self.current_solver + 1 < self.logistic_regression_search_parameters.solver.len()
        {
            self.current_alpha = 0;
            self.current_solver += 1;
        } else {
            self.current_alpha += 1;
            self.current_solver += 1;
        }

        Some(next)
    }
}

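// A minimal sketch of walking the search grid (the alpha values below are
// illustrative):
//
//     let grid = LogisticRegressionSearchParameters {
//         alpha: vec![0.0, 0.5, 1.0],
//         ..Default::default()
//     };
//     for params in grid {
//         // each `params` is one solver/alpha combination
//     }
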
impl<T: Number + FloatNumber> Default for LogisticRegressionSearchParameters<T> {
    fn default() -> Self {
        let default_params = LogisticRegressionParameters::default();

        LogisticRegressionSearchParameters {
            solver: vec![default_params.solver],
            alpha: vec![default_params.alpha],
        }
    }
}

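/// A fitted logistic regression model. For binary problems the coefficient
/// matrix has shape `1 x p`; for `k > 2` classes a multinomial model with a
/// `k x p` coefficient matrix and a `k x 1` intercept vector is estimated.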
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LogisticRegression<
    TX: Number + FloatNumber + RealNumber,
    TY: Number + Ord,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    coefficients: Option<X>,
    intercept: Option<X>,
    classes: Option<Vec<TY>>,
    num_attributes: usize,
    num_classes: usize,
    _phantom_tx: PhantomData<TX>,
    _phantom_y: PhantomData<Y>,
}

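/// Objective to be minimized: value `f` and gradient `df` of the (optionally
/// regularized) negative log-likelihood, evaluated at a flat weights-plus-bias
/// vector.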
trait ObjectiveFunction<T: Number + FloatNumber, X: Array2<T>> {
    fn f(&self, w_bias: &[T]) -> T;

    #[allow(clippy::ptr_arg)]
    fn df(&self, g: &mut Vec<T>, w_bias: &Vec<T>);

    #[allow(clippy::ptr_arg)]
    fn partial_dot(w: &[T], x: &X, v_col: usize, m_row: usize) -> T {
        let mut sum = T::zero();
        let p = x.shape().1;
        for i in 0..p {
            sum += *x.get((m_row, i)) * w[i + v_col];
        }

        sum + w[p + v_col]
    }
}

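/// Binary objective: negative log-likelihood of the logistic model, with an
/// optional L2 penalty on the weights (the bias term is not penalized):
///
///     f(w, b) = sum_i [ ln(1 + exp(w·x_i + b)) - y_i (w·x_i + b) ]
///               + (alpha / 2) ||w||^2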
struct BinaryObjectiveFunction<'a, T: Number + FloatNumber, X: Array2<T>> {
    x: &'a X,
    y: Vec<usize>,
    alpha: T,
    _phantom_t: PhantomData<T>,
}

impl<T: Number + FloatNumber> LogisticRegressionParameters<T> {
    pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self {
        self.solver = solver;
        self
    }
    pub fn with_alpha(mut self, alpha: T) -> Self {
        self.alpha = alpha;
        self
    }
}

impl<T: Number + FloatNumber> Default for LogisticRegressionParameters<T> {
    fn default() -> Self {
        LogisticRegressionParameters {
            solver: LogisticRegressionSolverName::default(),
            alpha: T::zero(),
        }
    }
}

impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    PartialEq for LogisticRegression<TX, TY, X, Y>
{
    fn eq(&self, other: &Self) -> bool {
        if self.num_classes != other.num_classes
            || self.num_attributes != other.num_attributes
            || self.classes().len() != other.classes().len()
        {
            false
        } else {
            for i in 0..self.classes().len() {
                if self.classes()[i] != other.classes()[i] {
                    return false;
                }
            }

            self.coefficients()
                .iterator(0)
                .zip(other.coefficients().iterator(0))
                .all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
                && self
                    .intercept()
                    .iterator(0)
                    .zip(other.intercept().iterator(0))
                    .all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
        }
    }
}

impl<T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
    for BinaryObjectiveFunction<'_, T, X>
{
    fn f(&self, w_bias: &[T]) -> T {
        let mut f = T::zero();
        let (n, p) = self.x.shape();

        for i in 0..n {
            let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
            f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx;
        }

        if self.alpha > T::zero() {
            let mut w_squared = T::zero();
            for w_bias_i in w_bias.iter().take(p) {
                w_squared += *w_bias_i * *w_bias_i;
            }
            f += T::from_f64(0.5).unwrap() * self.alpha * w_squared;
        }

        f
    }

    fn df(&self, g: &mut Vec<T>, w_bias: &Vec<T>) {
        g.copy_from(&Vec::zeros(g.len()));

        let (n, p) = self.x.shape();

        for i in 0..n {
            let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);

            let dyi = (T::from(self.y[i]).unwrap()) - wx.sigmoid();
            for (j, g_j) in g.iter_mut().enumerate().take(p) {
                *g_j -= dyi * *self.x.get((i, j));
            }
            g[p] -= dyi;
        }

        if self.alpha > T::zero() {
            for i in 0..p {
                let w = w_bias[i];
                g[i] += self.alpha * w;
            }
        }
    }
}

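/// Multinomial objective: softmax cross-entropy over `k` classes, with an
/// optional L2 penalty on the weights. The flat weight vector holds `k`
/// consecutive blocks of `p + 1` values; the last entry of each block is that
/// class's bias:
///
///     f(W) = -sum_i ln( softmax(W x_i + b)[y_i] ) + (alpha / 2) sum_j ||w_j||^2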
struct MultiClassObjectiveFunction<'a, T: Number + FloatNumber, X: Array2<T>> {
    x: &'a X,
    y: Vec<usize>,
    k: usize,
    alpha: T,
    _phantom_t: PhantomData<T>,
}

impl<T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
    for MultiClassObjectiveFunction<'_, T, X>
{
    fn f(&self, w_bias: &[T]) -> T {
        let mut f = T::zero();
        let mut prob = vec![T::zero(); self.k];
        let (n, p) = self.x.shape();
        for i in 0..n {
            for (j, prob_j) in prob.iter_mut().enumerate().take(self.k) {
                *prob_j = MultiClassObjectiveFunction::partial_dot(w_bias, self.x, j * (p + 1), i);
            }
            prob.softmax_mut();
            f -= prob[self.y[i]].ln();
        }

        if self.alpha > T::zero() {
            let mut w_squared = T::zero();
            for i in 0..self.k {
                for j in 0..p {
                    let wi = w_bias[i * (p + 1) + j];
                    w_squared += wi * wi;
                }
            }
            f += T::from_f64(0.5).unwrap() * self.alpha * w_squared;
        }

        f
    }

    fn df(&self, g: &mut Vec<T>, w: &Vec<T>) {
        g.copy_from(&Vec::zeros(g.len()));

        let mut prob = vec![T::zero(); self.k];
        let (n, p) = self.x.shape();

        for i in 0..n {
            for (j, prob_j) in prob.iter_mut().enumerate().take(self.k) {
                *prob_j = MultiClassObjectiveFunction::partial_dot(w, self.x, j * (p + 1), i);
            }

            prob.softmax_mut();

            for j in 0..self.k {
                let yi = (if self.y[i] == j { T::one() } else { T::zero() }) - prob[j];

                for l in 0..p {
                    let pos = j * (p + 1);
                    g[pos + l] -= yi * *self.x.get((i, l));
                }
                g[j * (p + 1) + p] -= yi;
            }
        }

        if self.alpha > T::zero() {
            for i in 0..self.k {
                for j in 0..p {
                    let pos = i * (p + 1);
                    let wi = w[pos + j];
                    g[pos + j] += self.alpha * wi;
                }
            }
        }
    }
}

impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, LogisticRegressionParameters<TX>>
    for LogisticRegression<TX, TY, X, Y>
{
    fn new() -> Self {
        Self {
            coefficients: Option::None,
            intercept: Option::None,
            classes: Option::None,
            num_attributes: 0,
            num_classes: 0,
            _phantom_tx: PhantomData,
            _phantom_y: PhantomData,
        }
    }

    fn fit(x: &X, y: &Y, parameters: LogisticRegressionParameters<TX>) -> Result<Self, Failed> {
        LogisticRegression::fit(x, y, parameters)
    }
}

impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    Predictor<X, Y> for LogisticRegression<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    LogisticRegression<TX, TY, X, Y>
{
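    /// Fits the model to feature matrix `x` (one sample per row) and label
    /// vector `y`. Two distinct labels are fitted with the binary objective;
    /// three or more with the multinomial (softmax) objective. Fails if `x`
    /// and `y` disagree on the number of rows or fewer than two classes are
    /// present.
    ///
    /// A minimal sketch (types and calls as exercised by the tests below):
    ///
    /// ```ignore
    /// let x = DenseMatrix::from_2d_array(&[&[1., -5.], &[2., 5.], &[3., -2.], &[1., 2.]]).unwrap();
    /// let y: Vec<i32> = vec![0, 0, 1, 1];
    /// let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
    /// let y_hat = lr.predict(&x).unwrap();
    /// ```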
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: LogisticRegressionParameters<TX>,
    ) -> Result<LogisticRegression<TX, TY, X, Y>, Failed> {
        let (x_nrows, num_attributes) = x.shape();
        let y_nrows = y.shape();

        if x_nrows != y_nrows {
            return Err(Failed::fit(
                "Number of rows of X doesn't match number of rows of Y",
            ));
        }

        let classes = y.unique();

        let k = classes.len();

        let mut yi: Vec<usize> = vec![0; y_nrows];

        for (i, yi_i) in yi.iter_mut().enumerate().take(y_nrows) {
            let yc = y.get(i);
            *yi_i = classes.iter().position(|c| yc == c).unwrap();
        }

        match k.cmp(&2) {
            Ordering::Less => Err(Failed::fit(&format!(
                "incorrect number of classes: {k}. Should be >= 2."
            ))),
            Ordering::Equal => {
                let x0 = Vec::zeros(num_attributes + 1);

                let objective = BinaryObjectiveFunction {
                    x,
                    y: yi,
                    alpha: parameters.alpha,
                    _phantom_t: PhantomData,
                };

                let result = Self::minimize(x0, objective);

                let weights = X::from_iterator(result.x.into_iter(), 1, num_attributes + 1, 0);
                let coefficients = weights.slice(0..1, 0..num_attributes);
                let intercept = weights.slice(0..1, num_attributes..num_attributes + 1);

                Ok(LogisticRegression {
                    coefficients: Some(X::from_slice(coefficients.as_ref())),
                    intercept: Some(X::from_slice(intercept.as_ref())),
                    classes: Some(classes),
                    num_attributes,
                    num_classes: k,
                    _phantom_tx: PhantomData,
                    _phantom_y: PhantomData,
                })
            }
            Ordering::Greater => {
                let x0 = Vec::zeros((num_attributes + 1) * k);

                let objective = MultiClassObjectiveFunction {
                    x,
                    y: yi,
                    k,
                    alpha: parameters.alpha,
                    _phantom_t: PhantomData,
                };

                let result = Self::minimize(x0, objective);
                let weights = X::from_iterator(result.x.into_iter(), k, num_attributes + 1, 0);
                let coefficients = weights.slice(0..k, 0..num_attributes);
                let intercept = weights.slice(0..k, num_attributes..num_attributes + 1);

                Ok(LogisticRegression {
                    coefficients: Some(X::from_slice(coefficients.as_ref())),
                    intercept: Some(X::from_slice(intercept.as_ref())),
                    classes: Some(classes),
                    num_attributes,
                    num_classes: k,
                    _phantom_tx: PhantomData,
                    _phantom_y: PhantomData,
                })
            }
        }
    }

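    /// Predicts a class label for every row of `x`. With two classes the
    /// decision rule is `sigmoid(w·x + b) > 0.5`; with more classes the argmax
    /// of the per-class linear scores is taken.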
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let n = x.shape().0;
        let mut result = Y::zeros(n);
        if self.num_classes == 2 {
            let y_hat = x.ab(false, self.coefficients(), true);
            let intercept = *self.intercept().get((0, 0));
            for (i, y_hat_i) in y_hat.iterator(0).enumerate().take(n) {
                result.set(
                    i,
                    self.classes()[usize::from(
                        RealNumber::sigmoid(*y_hat_i + intercept) > RealNumber::half(),
                    )],
                );
            }
        } else {
            let mut y_hat = x.matmul(&self.coefficients().transpose());
            for r in 0..n {
                for c in 0..self.num_classes {
                    y_hat.set((r, c), *y_hat.get((r, c)) + *self.intercept().get((c, 0)));
                }
            }
            let class_idxs = y_hat.argmax(1);
            for (i, class_i) in class_idxs.iter().enumerate().take(n) {
                result.set(i, self.classes()[*class_i]);
            }
        }
        Ok(result)
    }

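    /// Returns the fitted coefficients (`1 x p` for binary models, `k x p`
    /// otherwise). Panics if the model has not been fitted.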
    pub fn coefficients(&self) -> &X {
        self.coefficients.as_ref().unwrap()
    }

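    /// Returns the fitted intercept(s) (`1 x 1` for binary models, `k x 1`
    /// otherwise). Panics if the model has not been fitted.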
    pub fn intercept(&self) -> &X {
        self.intercept.as_ref().unwrap()
    }

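    /// Returns the distinct class labels seen during fitting. Panics if the
    /// model has not been fitted.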
    pub fn classes(&self) -> &Vec<TY> {
        self.classes.as_ref().unwrap()
    }

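    /// Minimizes `objective` starting from `x0` using the L-BFGS optimizer
    /// with a third-order backtracking line search.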
    fn minimize(
        x0: Vec<TX>,
        objective: impl ObjectiveFunction<TX, X>,
    ) -> OptimizerResult<TX, Vec<TX>> {
        let f = |w: &Vec<TX>| -> TX { objective.f(w) };

        let df = |g: &mut Vec<TX>, w: &Vec<TX>| objective.df(g, w);

        let ls: Backtracking<TX> = Backtracking {
            order: FunctionOrder::THIRD,
            ..Default::default()
        };
        let optimizer: LBFGS = Default::default();

        optimizer.optimize(&f, &df, &x0, &ls)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "datasets")]
    use crate::dataset::generator::make_blobs;
    use crate::linalg::basic::arrays::Array;
    use crate::linalg::basic::matrix::DenseMatrix;

    #[test]
    fn search_parameters() {
        let parameters = LogisticRegressionSearchParameters {
            alpha: vec![0., 1.],
            ..Default::default()
        };
        let mut iter = parameters.into_iter();
        assert_eq!(iter.next().unwrap().alpha, 0.);
        assert_eq!(
            iter.next().unwrap().solver,
            LogisticRegressionSolverName::LBFGS
        );
        assert!(iter.next().is_none());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn multiclass_objective_f() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., -5.],
            &[2., 5.],
            &[3., -2.],
            &[1., 2.],
            &[2., 0.],
            &[6., -5.],
            &[7., 5.],
            &[6., -2.],
            &[7., 2.],
            &[6., 0.],
            &[8., -5.],
            &[9., 5.],
            &[10., -2.],
            &[8., 2.],
            &[9., 0.],
        ])
        .unwrap();

        let y = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];

        let objective = MultiClassObjectiveFunction {
            x: &x,
            y: y.clone(),
            k: 3,
            alpha: 0.0,
            _phantom_t: PhantomData,
        };

        let mut g = vec![0f64; 9];

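        // `df` overwrites `g` on every call, so the repeated call below yields
        // the same gradient rather than accumulating.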
        objective.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
        objective.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);

        assert!((g[0] + 33.000068218163484).abs() < f64::EPSILON);

        let f = objective.f(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]);

        assert!((f - 408.0052230582765).abs() < f64::EPSILON);

        let objective_reg = MultiClassObjectiveFunction {
            x: &x,
            y,
            k: 3,
            alpha: 1.0,
            _phantom_t: PhantomData,
        };

        let f = objective_reg.f(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
        assert!((f - 487.5052).abs() < 1e-4);

        objective_reg.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
        assert!((g[0].abs() - 32.0).abs() < 1e-4);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn binary_objective_f() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., -5.],
            &[2., 5.],
            &[3., -2.],
            &[1., 2.],
            &[2., 0.],
            &[6., -5.],
            &[7., 5.],
            &[6., -2.],
            &[7., 2.],
            &[6., 0.],
            &[8., -5.],
            &[9., 5.],
            &[10., -2.],
            &[8., 2.],
            &[9., 0.],
        ])
        .unwrap();

        let y = vec![0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1];

        let objective = BinaryObjectiveFunction {
            x: &x,
            y: y.clone(),
            alpha: 0.0,
            _phantom_t: PhantomData,
        };

        let mut g = vec![0f64; 3];

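        // As above, `df` overwrites `g`, so the duplicate call changes nothing.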
        objective.df(&mut g, &vec![1., 2., 3.]);
        objective.df(&mut g, &vec![1., 2., 3.]);

        assert!((g[0] - 26.051064349381285).abs() < f64::EPSILON);
        assert!((g[1] - 10.239000702928523).abs() < f64::EPSILON);
        assert!((g[2] - 3.869294270156324).abs() < f64::EPSILON);

        let f = objective.f(&[1., 2., 3.]);

        assert!((f - 59.76994756647412).abs() < f64::EPSILON);

        let objective_reg = BinaryObjectiveFunction {
            x: &x,
            y,
            alpha: 1.0,
            _phantom_t: PhantomData,
        };

        let f = objective_reg.f(&[1., 2., 3.]);
        assert!((f - 62.2699).abs() < 1e-4);

        objective_reg.df(&mut g, &vec![1., 2., 3.]);
        assert!((g[0] - 27.0511).abs() < 1e-4);
        assert!((g[1] - 12.239).abs() < 1e-4);
        assert!((g[2] - 3.8693).abs() < 1e-4);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn lr_fit_predict() {
        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
            &[1., -5.],
            &[2., 5.],
            &[3., -2.],
            &[1., 2.],
            &[2., 0.],
            &[6., -5.],
            &[7., 5.],
            &[6., -2.],
            &[7., 2.],
            &[6., 0.],
            &[8., -5.],
            &[9., 5.],
            &[10., -2.],
            &[8., 2.],
            &[9., 0.],
        ])
        .unwrap();
        let y: Vec<i32> = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];

        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();

        assert_eq!(lr.coefficients().shape(), (3, 2));
        assert_eq!(lr.intercept().shape(), (3, 1));

        assert!((*lr.coefficients().get((0, 0)) - 0.0435).abs() < 1e-4);
        assert!(
            (*lr.intercept().get((0, 0)) - 0.1250).abs() < 1e-4,
            "expected to be less than 1e-4, got {}",
            (*lr.intercept().get((0, 0)) - 0.1250).abs()
        );

        let y_hat = lr.predict(&x).unwrap();

        assert_eq!(y_hat, vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);
    }

755
756 #[cfg(feature = "datasets")]
757 #[cfg_attr(
758 all(target_arch = "wasm32", not(target_os = "wasi")),
759 wasm_bindgen_test::wasm_bindgen_test
760 )]
761 #[test]
762 fn lr_fit_predict_multiclass() {
763 let blobs = make_blobs(15, 4, 3);
764
765 let x: DenseMatrix<f32> = DenseMatrix::from_iterator(blobs.data.into_iter(), 15, 4, 0);
766 let y: Vec<i32> = blobs.target.into_iter().map(|v| v as i32).collect();
767
768 let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
769
770 let y_hat = lr.predict(&x).unwrap();
771
772 assert_eq!(y_hat, vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]);
773
774 let lr_reg = LogisticRegression::fit(
775 &x,
776 &y,
777 LogisticRegressionParameters::default().with_alpha(10.0),
778 )
779 .unwrap();
780
781 let reg_coeff_sum: f32 = lr_reg.coefficients().abs().iter().sum();
782 let coeff: f32 = lr.coefficients().abs().iter().sum();
783
784 assert!(reg_coeff_sum < coeff);
785 }
786
787 #[cfg(feature = "datasets")]
788 #[cfg_attr(
789 all(target_arch = "wasm32", not(target_os = "wasi")),
790 wasm_bindgen_test::wasm_bindgen_test
791 )]
792 #[test]
793 fn lr_fit_predict_binary() {
794 let blobs = make_blobs(20, 4, 2);
795
796 let x = DenseMatrix::from_iterator(blobs.data.into_iter(), 20, 4, 0);
797 let y: Vec<i32> = blobs.target.into_iter().map(|v| v as i32).collect();
798
799 let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
800
801 let y_hat = lr.predict(&x).unwrap();
802
803 assert_eq!(
804 y_hat,
805 vec![0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
806 );
807
808 let lr_reg = LogisticRegression::fit(
809 &x,
810 &y,
811 LogisticRegressionParameters::default().with_alpha(10.0),
812 )
813 .unwrap();
814
815 let reg_coeff_sum: f32 = lr_reg.coefficients().abs().iter().sum();
816 let coeff: f32 = lr.coefficients().abs().iter().sum();
817
818 assert!(reg_coeff_sum < coeff);
819 }
820
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    #[cfg(feature = "serde")]
    fn serde() {
        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
            &[1., -5.],
            &[2., 5.],
            &[3., -2.],
            &[1., 2.],
            &[2., 0.],
            &[6., -5.],
            &[7., 5.],
            &[6., -2.],
            &[7., 2.],
            &[6., 0.],
            &[8., -5.],
            &[9., 5.],
            &[10., -2.],
            &[8., 2.],
            &[9., 0.],
        ])
        .unwrap();
        let y: Vec<i32> = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];

        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();

        let deserialized_lr: LogisticRegression<f64, i32, DenseMatrix<f64>, Vec<i32>> =
            serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();

        assert_eq!(lr, deserialized_lr);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn lr_fit_predict_iris() {
        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
        ])
        .unwrap();
        let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];

        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
        let lr_reg = LogisticRegression::fit(
            &x,
            &y,
            LogisticRegressionParameters::default().with_alpha(1.0),
        )
        .unwrap();

        let y_hat = lr.predict(&x).unwrap();

        let error: i32 = y.into_iter().zip(y_hat).map(|(a, b)| (a - b).abs()).sum();

        assert!(error <= 1);

        let reg_coeff_sum: f32 = lr_reg.coefficients().abs().iter().sum();
        let coeff: f32 = lr.coefficients().abs().iter().sum();

        assert!(reg_coeff_sum < coeff);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn lr_fit_predict_random() {
        let x: DenseMatrix<f32> = DenseMatrix::rand(52181, 94);
        let y1: Vec<i32> = vec![1; 2181];
        let y2: Vec<i32> = vec![0; 50000];
        let y: Vec<i32> = y1.into_iter().chain(y2).collect();

        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
        let lr_reg = LogisticRegression::fit(
            &x,
            &y,
            LogisticRegressionParameters::default().with_alpha(1.0),
        )
        .unwrap();

        let y_hat = lr.predict(&x).unwrap();
        let y_hat_reg = lr_reg.predict(&x).unwrap();

        assert_eq!(y.len(), y_hat.len());
        assert_eq!(y.len(), y_hat_reg.len());
    }

    #[test]
    fn test_logit() {
        let x: &DenseMatrix<f64> = &DenseMatrix::rand(52181, 94);
        let y1: Vec<u32> = vec![1; 2181];
        let y2: Vec<u32> = vec![0; 50000];
        let y: &Vec<u32> = &(y1.into_iter().chain(y2).collect());
        println!("y vec height: {:?}", y.len());
        println!("x matrix shape: {:?}", x.shape());

        let lr = LogisticRegression::fit(x, y, Default::default()).unwrap();
        let y_hat = lr.predict(x).unwrap();

        println!("y_hat shape: {:?}", y_hat.shape());

        assert_eq!(y_hat.shape(), 52181);
    }
}