sklears_python/linear/logistic_regression.rs

//! Python bindings for Logistic Regression
//!
//! This module provides Python bindings for Logistic Regression,
//! offering scikit-learn compatible interfaces for binary classification
//! using the sklears-linear crate.

use super::common::*;
use pyo3::types::PyDict;
use pyo3::Bound;
use sklears_core::traits::{Fit, Predict, PredictProba, Score, Trained};
use sklears_linear::{LogisticRegression, LogisticRegressionConfig, Penalty, Solver};

/// Python-specific configuration wrapper for LogisticRegression
#[derive(Debug, Clone)]
pub struct PyLogisticRegressionConfig {
    pub penalty: String,
    pub c: f64,
    pub fit_intercept: bool,
    pub max_iter: usize,
    pub tol: f64,
    pub solver: String,
    pub random_state: Option<i32>,
    pub class_weight: Option<String>,
    pub multi_class: String,
    pub warm_start: bool,
    pub n_jobs: Option<i32>,
    pub l1_ratio: Option<f64>,
}

impl Default for PyLogisticRegressionConfig {
    fn default() -> Self {
        Self {
            penalty: "l2".to_string(),
            c: 1.0,
            fit_intercept: true,
            max_iter: 100,
            tol: 1e-4,
            solver: "lbfgs".to_string(),
            random_state: None,
            class_weight: None,
            multi_class: "auto".to_string(),
            warm_start: false,
            n_jobs: None,
            l1_ratio: None,
        }
    }
}

impl From<PyLogisticRegressionConfig> for LogisticRegressionConfig {
    fn from(py_config: PyLogisticRegressionConfig) -> Self {
        // Convert the penalty string to the Penalty enum. sklearn's `C` is the
        // inverse of the regularization strength, so alpha = 1.0 / C.
        let penalty = match py_config.penalty.as_str() {
            "l1" => Penalty::L1(1.0 / py_config.c),
            "l2" => Penalty::L2(1.0 / py_config.c),
            "elasticnet" => Penalty::ElasticNet {
                alpha: 1.0 / py_config.c,
                l1_ratio: py_config.l1_ratio.unwrap_or(0.5),
            },
            _ => Penalty::L2(1.0 / py_config.c), // Default to L2
        };

        // Convert the solver string to the Solver enum
        let solver = match py_config.solver.as_str() {
            "lbfgs" => Solver::Lbfgs,
            "sag" => Solver::Sag,
            "saga" => Solver::Saga,
            "newton-cg" => Solver::Newton,
            _ => Solver::Auto, // Fall back to automatic solver selection
        };

        LogisticRegressionConfig {
            penalty,
            solver,
            max_iter: py_config.max_iter,
            tol: py_config.tol,
            fit_intercept: py_config.fit_intercept,
            random_state: py_config.random_state.map(|s| s as u64),
        }
    }
}
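
// A minimal sketch that pins down the alpha = 1.0 / C convention implemented
// by the `From` impl above. It assumes only the `Penalty` variant shapes
// already used in this file; the module and test names are illustrative, not
// part of the crate's API.
#[cfg(test)]
mod config_conversion_tests {
    use super::*;

    #[test]
    fn c_is_inverted_into_regularization_strength() {
        // C = 2.0 should become alpha = 1.0 / 2.0 = 0.5 for an L1 penalty.
        let py_config = PyLogisticRegressionConfig {
            penalty: "l1".to_string(),
            c: 2.0,
            ..Default::default()
        };
        let config: LogisticRegressionConfig = py_config.into();
        assert!(matches!(config.penalty, Penalty::L1(alpha) if (alpha - 0.5).abs() < 1e-12));
    }
}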

/// Logistic Regression (aka logit, MaxEnt) classifier.
///
/// In the multiclass case, the training algorithm uses the one-vs-rest (OvR)
/// scheme if the 'multi_class' option is set to 'ovr', and uses the
/// cross-entropy loss if the 'multi_class' option is set to 'multinomial'.
/// (Currently the 'multinomial' option is supported only by the 'lbfgs',
/// 'sag', 'saga' and 'newton-cg' solvers.)
///
/// This class implements regularized logistic regression using various solvers.
/// **Note that regularization is applied by default**. This binding accepts
/// dense, C-ordered arrays containing 64-bit floats; other dtypes may be
/// rejected rather than converted.
///
/// The 'newton-cg', 'sag', and 'lbfgs' solvers support only L2 regularization
/// with primal formulation, or no regularization. The Elastic-Net regularization
/// is only supported by the 'saga' solver.
///
/// Parameters
/// ----------
/// penalty : {'l1', 'l2', 'elasticnet'}, default='l2'
///     Specify the norm of the penalty:
///
///     - 'l2': add an L2 penalty term (the default);
///     - 'l1': add an L1 penalty term;
///     - 'elasticnet': both L1 and L2 penalty terms are added.
///
/// tol : float, default=1e-4
///     Tolerance for the stopping criteria.
///
/// C : float, default=1.0
///     Inverse of regularization strength; must be a positive float.
///     As in support vector machines, smaller values specify stronger
///     regularization.
///
/// fit_intercept : bool, default=True
///     Specifies if a constant (a.k.a. bias or intercept) should be
///     added to the decision function.
///
/// class_weight : 'balanced', default=None
///     If not given, all classes are supposed to have weight one.
///     (Dict-valued weights are not yet supported by this binding.)
///
///     The "balanced" mode uses the values of y to automatically adjust
///     weights inversely proportional to class frequencies in the input data
///     as ``n_samples / (n_classes * np.bincount(y))``.
///
/// random_state : int, default=None
///     Used when ``solver`` is 'sag' or 'saga' to shuffle the
///     data. See :term:`Glossary <random_state>` for details.
///
/// solver : {'lbfgs', 'newton-cg', 'sag', 'saga'}, default='lbfgs'
///     Algorithm to use in the optimization problem. Default is 'lbfgs'.
///     To choose a solver, you might want to consider the following aspects:
///
///         - For small datasets, 'lbfgs' is a good choice, whereas 'sag'
///           and 'saga' are faster for large ones;
///         - For multiclass problems, only 'newton-cg', 'sag', 'saga' and
///           'lbfgs' handle multinomial loss.
///
/// max_iter : int, default=100
///     Maximum number of iterations taken for the solvers to converge.
///
/// multi_class : {'auto', 'ovr', 'multinomial'}, default='auto'
///     If the option chosen is 'ovr', then a binary problem is fit for each
///     label. For 'multinomial' the loss minimised is the multinomial loss fit
///     across the entire probability distribution, *even when the data is
///     binary*. 'auto' selects 'ovr' if the data is binary,
///     and otherwise selects 'multinomial'.
///
/// warm_start : bool, default=False
///     When set to True, reuse the solution of the previous call to fit as
///     initialization; otherwise, just erase the previous solution.
///     See :term:`the Glossary <warm_start>`.
///
/// n_jobs : int, default=None
///     Number of CPU cores used when parallelizing over classes if
///     ``multi_class='ovr'``. ``None`` means 1 unless in a
///     :obj:`joblib.parallel_backend` context. ``-1`` means using all
///     processors. See :term:`Glossary <n_jobs>` for more details.
///
/// l1_ratio : float, default=None
///     The Elastic-Net mixing parameter, with ``0 <= l1_ratio <= 1``. Only
///     used if ``penalty='elasticnet'``. Setting ``l1_ratio=0`` is equivalent
///     to using ``penalty='l2'``, while setting ``l1_ratio=1`` is equivalent
///     to using ``penalty='l1'``. For ``0 < l1_ratio < 1``, the penalty is a
///     combination of L1 and L2.
///
/// Attributes
/// ----------
/// classes_ : ndarray of shape (n_classes,)
///     A list of class labels known to the classifier.
///
/// coef_ : ndarray of shape (n_features,)
///     Coefficients of the features in the decision function. This binding
///     currently exposes the binary-case coefficients as a 1-D array.
///
/// intercept_ : float
///     Intercept (a.k.a. bias) added to the decision function.
///
///     If `fit_intercept` is set to False, the intercept is 0.0.
///
/// n_features_in_ : int
///     Number of features seen during :term:`fit`.
///
/// Examples
/// --------
/// >>> from sklears_python import LogisticRegression
/// >>> import numpy as np
/// >>> X = np.array([[-1., -1.], [-2., -1.], [1., 1.], [2., 1.]])
/// >>> y = np.array([0., 0., 1., 1.])
/// >>> clf = LogisticRegression(random_state=0)
/// >>> clf.fit(X, y)
/// >>> clf.predict(X[:2, :])
/// array([0., 0.])
/// >>> clf.predict_proba(X[:2, :])
/// array([[...]])
/// >>> clf.score(X, y)
/// 1.0
///
/// Notes
/// -----
/// The underlying implementation uses optimized solvers from sklears-linear.
///
/// References
/// ----------
/// L-BFGS-B -- Software for Large-scale Bound-constrained Optimization.
/// Ciyou Zhu, Richard Byrd, Jorge Nocedal and Jose Luis Morales.
/// http://users.iems.northwestern.edu/~nocedal/lbfgsb.html
///
/// SAG -- Mark Schmidt, Nicolas Le Roux, and Francis Bach.
/// Minimizing Finite Sums with the Stochastic Average Gradient.
/// https://hal.inria.fr/hal-00860051/document
///
/// SAGA -- Defazio, A., Bach, F. & Lacoste-Julien, S. (2014).
/// SAGA: A Fast Incremental Gradient Method With Support
/// for Non-Strongly Convex Composite Objectives.
/// https://arxiv.org/abs/1407.0202
#[pyclass(name = "LogisticRegression")]
pub struct PyLogisticRegression {
    py_config: PyLogisticRegressionConfig,
    fitted_model: Option<LogisticRegression<Trained>>,
}

#[pymethods]
impl PyLogisticRegression {
    #[new]
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (penalty="l2", dual=false, tol=1e-4, c=1.0, fit_intercept=true, intercept_scaling=1.0, class_weight=None, random_state=None, solver="lbfgs", max_iter=100, multi_class="auto", verbose=0, warm_start=false, n_jobs=None, l1_ratio=None))]
    fn new(
        penalty: &str,
        dual: bool,
        tol: f64,
        c: f64,
        fit_intercept: bool,
        intercept_scaling: f64,
        class_weight: Option<&str>,
        random_state: Option<i32>,
        solver: &str,
        max_iter: usize,
        multi_class: &str,
        verbose: i32,
        warm_start: bool,
        n_jobs: Option<i32>,
        l1_ratio: Option<f64>,
    ) -> Self {
        // Note: `dual`, `intercept_scaling`, and `verbose` are accepted for
        // sklearn signature compatibility but are not used by this implementation.
        let _dual = dual;
        let _intercept_scaling = intercept_scaling;
        let _verbose = verbose;

        let py_config = PyLogisticRegressionConfig {
            penalty: penalty.to_string(),
            c,
            fit_intercept,
            max_iter,
            tol,
            solver: solver.to_string(),
            random_state,
            class_weight: class_weight.map(|s| s.to_string()),
            multi_class: multi_class.to_string(),
            warm_start,
            n_jobs,
            l1_ratio,
        };

        Self {
            py_config,
            fitted_model: None,
        }
    }

    /// Fit the logistic regression model
    fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
        let x_array = pyarray_to_core_array2(x)?;
        let y_array = pyarray_to_core_array1(y)?;

        // Validate input arrays
        validate_fit_arrays(&x_array, &y_array)?;

        // Create a sklears-linear model with the logistic regression configuration
        let model = LogisticRegression::new()
            .max_iter(self.py_config.max_iter)
            .fit_intercept(self.py_config.fit_intercept);

        // Apply the penalty if specified (alpha = 1.0 / C, as in the From impl above)
        let model = match self.py_config.penalty.as_str() {
            "l1" => model.penalty(Penalty::L1(1.0 / self.py_config.c)),
            "l2" => model.penalty(Penalty::L2(1.0 / self.py_config.c)),
            "elasticnet" => model.penalty(Penalty::ElasticNet {
                alpha: 1.0 / self.py_config.c,
                l1_ratio: self.py_config.l1_ratio.unwrap_or(0.5),
            }),
            _ => model, // Unrecognized penalty: keep the builder's default
        };

        // Apply the solver if specified
        let model = match self.py_config.solver.as_str() {
            "lbfgs" => model.solver(Solver::Lbfgs),
            "sag" => model.solver(Solver::Sag),
            "saga" => model.solver(Solver::Saga),
            "newton-cg" => model.solver(Solver::Newton),
            _ => model.solver(Solver::Auto),
        };

        // Apply the random state if specified
        let model = if let Some(rs) = self.py_config.random_state {
            model.random_state(rs as u64)
        } else {
            model
        };

        // Fit the model using sklears-linear's implementation
        match model.fit(&x_array, &y_array) {
            Ok(fitted_model) => {
                self.fitted_model = Some(fitted_model);
                Ok(())
            }
            Err(e) => Err(PyValueError::new_err(format!(
                "Failed to fit Logistic Regression model: {:?}",
                e
            ))),
        }
    }

    /// Predict class labels for samples
    fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = pyarray_to_core_array2(x)?;
        validate_predict_array(&x_array)?;

        match fitted.predict(&x_array) {
            Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
            Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
        }
    }

    /// Predict class probabilities for samples
    fn predict_proba(
        &self,
        py: Python<'_>,
        x: PyReadonlyArray2<f64>,
    ) -> PyResult<Py<PyArray2<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = pyarray_to_core_array2(x)?;
        validate_predict_array(&x_array)?;

        match fitted.predict_proba(&x_array) {
            Ok(probabilities) => core_array2_to_py(py, &probabilities),
            Err(e) => Err(PyValueError::new_err(format!(
                "Probability prediction failed: {:?}",
                e
            ))),
        }
    }

    /// Get model coefficients (1-D array of length n_features for the binary case)
    #[getter]
    fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        Ok(core_array1_to_py(py, fitted.coef()))
    }

    /// Get model intercept
    #[getter]
    fn intercept_(&self) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        Ok(fitted.intercept().unwrap_or(0.0))
    }

    /// Get unique class labels
    #[getter]
    fn classes_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        Ok(core_array1_to_py(py, fitted.classes()))
    }

    /// Calculate accuracy score
    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = pyarray_to_core_array2(x)?;
        let y_array = pyarray_to_core_array1(y)?;

        match fitted.score(&x_array, &y_array) {
            Ok(score) => Ok(score),
            Err(e) => Err(PyValueError::new_err(format!(
                "Score calculation failed: {:?}",
                e
            ))),
        }
    }

    /// Get number of features
    #[getter]
    fn n_features_in_(&self) -> PyResult<usize> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        // Infer the number of features from the coefficient array length
        Ok(fitted.coef().len())
    }

    /// Return parameters for this estimator (sklearn compatibility)
    fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
        let _deep = deep.unwrap_or(true);

        let dict = PyDict::new(py);

        dict.set_item("penalty", &self.py_config.penalty)?;
        dict.set_item("C", self.py_config.c)?;
        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
        dict.set_item("max_iter", self.py_config.max_iter)?;
        dict.set_item("tol", self.py_config.tol)?;
        dict.set_item("solver", &self.py_config.solver)?;
        dict.set_item("random_state", self.py_config.random_state)?;
        dict.set_item("class_weight", &self.py_config.class_weight)?;
        dict.set_item("multi_class", &self.py_config.multi_class)?;
        dict.set_item("warm_start", self.py_config.warm_start)?;
        dict.set_item("n_jobs", self.py_config.n_jobs)?;
        dict.set_item("l1_ratio", self.py_config.l1_ratio)?;

        Ok(dict.into())
    }

    /// Set parameters for this estimator (sklearn compatibility)
    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
        // Update configuration parameters
        if let Some(penalty) = kwargs.get_item("penalty")? {
            self.py_config.penalty = penalty.extract()?;
        }
        if let Some(c) = kwargs.get_item("C")? {
            self.py_config.c = c.extract()?;
        }
        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
            self.py_config.fit_intercept = fit_intercept.extract()?;
        }
        if let Some(max_iter) = kwargs.get_item("max_iter")? {
            self.py_config.max_iter = max_iter.extract()?;
        }
        if let Some(tol) = kwargs.get_item("tol")? {
            self.py_config.tol = tol.extract()?;
        }
        if let Some(solver) = kwargs.get_item("solver")? {
            self.py_config.solver = solver.extract()?;
        }
        if let Some(random_state) = kwargs.get_item("random_state")? {
            self.py_config.random_state = random_state.extract()?;
        }
        if let Some(class_weight) = kwargs.get_item("class_weight")? {
            self.py_config.class_weight = class_weight.extract()?;
        }
        if let Some(multi_class) = kwargs.get_item("multi_class")? {
            self.py_config.multi_class = multi_class.extract()?;
        }
        if let Some(warm_start) = kwargs.get_item("warm_start")? {
            self.py_config.warm_start = warm_start.extract()?;
        }
        if let Some(n_jobs) = kwargs.get_item("n_jobs")? {
            self.py_config.n_jobs = n_jobs.extract()?;
        }
        if let Some(l1_ratio) = kwargs.get_item("l1_ratio")? {
            self.py_config.l1_ratio = l1_ratio.extract()?;
        }

        // Clear any fitted model since the configuration changed; the test
        // sketch at the bottom of this file exercises this reset.
        self.fitted_model = None;

        Ok(())
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!(
            "LogisticRegression(penalty='{}', C={}, fit_intercept={}, max_iter={}, tol={}, solver='{}', random_state={:?}, multi_class='{}')",
            self.py_config.penalty,
            self.py_config.c,
            self.py_config.fit_intercept,
            self.py_config.max_iter,
            self.py_config.tol,
            self.py_config.solver,
            self.py_config.random_state,
            self.py_config.multi_class
        )
    }
}
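
// A hedged sketch of the Python-facing parameter API: it checks that
// `set_params` updates the stored config and drops any fitted model, forcing a
// refit. It assumes an embedded interpreter can be initialized via pyo3's
// `prepare_freethreaded_python`, and uses only constructs already present in
// this file; the module and test names are illustrative, not part of the
// crate's API.
#[cfg(test)]
mod py_param_api_tests {
    use super::*;

    #[test]
    fn set_params_updates_config_and_resets_fit() {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            // Mirror the Python-side defaults declared in #[pyo3(signature = ...)].
            let mut est = PyLogisticRegression::new(
                "l2", false, 1e-4, 1.0, true, 1.0, None, None, "lbfgs", 100, "auto", 0,
                false, None, None,
            );

            // set_params should update the stored configuration...
            let kwargs = PyDict::new(py);
            kwargs.set_item("C", 0.5).unwrap();
            kwargs.set_item("solver", "saga").unwrap();
            est.set_params(&kwargs).unwrap();
            assert_eq!(est.py_config.c, 0.5);
            assert_eq!(est.py_config.solver, "saga");

            // ...and clear any previously fitted model.
            assert!(est.fitted_model.is_none());
        });
    }
}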