// sklears_python/linear/logistic_regression.rs
1//! Python bindings for Logistic Regression
2//!
3//! This module provides Python bindings for Logistic Regression,
4//! offering scikit-learn compatible interfaces for binary classification
5//! using the sklears-linear crate.
6
7use super::common::*;
8use pyo3::types::PyDict;
9use pyo3::Bound;
10use sklears_core::traits::{Fit, Predict, PredictProba, Score, Trained};
11use sklears_linear::{LogisticRegression, LogisticRegressionConfig, Penalty, Solver};
12
/// Python-specific configuration wrapper for LogisticRegression
///
/// Mirrors the scikit-learn `LogisticRegression` constructor parameters as
/// plain Rust values. Several fields (`class_weight`, `multi_class`,
/// `warm_start`, `n_jobs`) are kept only for sklearn API compatibility
/// (`get_params`/`set_params`) and are not forwarded to the native
/// `sklears_linear` configuration — see the `From` conversion in this module.
#[derive(Debug, Clone)]
pub struct PyLogisticRegressionConfig {
    // Regularization kind: "l1", "l2", or "elasticnet".
    pub penalty: String,
    // Inverse of regularization strength (sklearn's `C`); the native
    // penalty weight is computed as `1.0 / c`.
    pub c: f64,
    // Whether to fit a bias/intercept term.
    pub fit_intercept: bool,
    // Maximum number of solver iterations.
    pub max_iter: usize,
    // Convergence tolerance for the stopping criterion.
    pub tol: f64,
    // Solver name: "lbfgs", "sag", "saga", or "newton-cg".
    pub solver: String,
    // Optional random seed (converted to u64 for the native model).
    pub random_state: Option<i32>,
    // Stored for sklearn compatibility only; not forwarded to the backend.
    pub class_weight: Option<String>,
    // Stored for sklearn compatibility only; not forwarded to the backend.
    pub multi_class: String,
    // Stored for sklearn compatibility only; not forwarded to the backend.
    pub warm_start: bool,
    // Stored for sklearn compatibility only; not forwarded to the backend.
    pub n_jobs: Option<i32>,
    // Elastic-Net mixing parameter; only meaningful with penalty="elasticnet".
    pub l1_ratio: Option<f64>,
}
29
30impl Default for PyLogisticRegressionConfig {
31    fn default() -> Self {
32        Self {
33            penalty: "l2".to_string(),
34            c: 1.0,
35            fit_intercept: true,
36            max_iter: 100,
37            tol: 1e-4,
38            solver: "lbfgs".to_string(),
39            random_state: None,
40            class_weight: None,
41            multi_class: "auto".to_string(),
42            warm_start: false,
43            n_jobs: None,
44            l1_ratio: None,
45        }
46    }
47}
48
49impl From<PyLogisticRegressionConfig> for LogisticRegressionConfig {
50    fn from(py_config: PyLogisticRegressionConfig) -> Self {
51        // Convert penalty string to Penalty enum
52        let penalty = match py_config.penalty.as_str() {
53            "l1" => Penalty::L1(1.0 / py_config.c),
54            "l2" => Penalty::L2(1.0 / py_config.c),
55            "elasticnet" => Penalty::ElasticNet {
56                alpha: 1.0 / py_config.c,
57                l1_ratio: py_config.l1_ratio.unwrap_or(0.5),
58            },
59            _ => Penalty::L2(1.0 / py_config.c), // Default to L2
60        };
61
62        // Convert solver string to Solver enum
63        let solver = match py_config.solver.as_str() {
64            "lbfgs" => Solver::Lbfgs,
65            "sag" => Solver::Sag,
66            "saga" => Solver::Saga,
67            "newton-cg" => Solver::Newton,
68            _ => Solver::Auto, // Default to Auto
69        };
70
71        LogisticRegressionConfig {
72            penalty,
73            solver,
74            max_iter: py_config.max_iter,
75            tol: py_config.tol,
76            fit_intercept: py_config.fit_intercept,
77            random_state: py_config.random_state.map(|s| s as u64),
78        }
79    }
80}
81
/// Logistic Regression (aka logit, MaxEnt) classifier.
///
/// In the multiclass case, the training algorithm uses the one-vs-rest (OvR)
/// scheme if the 'multi_class' option is set to 'ovr', and uses the
/// cross-entropy loss if the 'multi_class' option is set to 'multinomial'.
///
/// This class implements regularized logistic regression using various solvers.
/// **Note that regularization is applied by default**.
///
/// # Parameters
///
/// - `penalty` - One of "l1", "l2", "elasticnet". Default: "l2"
/// - `tol` - Tolerance for stopping criteria. Default: 1e-4
/// - `c` - Inverse of regularization strength. Default: 1.0
/// - `fit_intercept` - Whether to add bias term. Default: true
/// - `solver` - One of "lbfgs", "newton-cg", "sag", "saga". Default: "lbfgs"
/// - `max_iter` - Maximum iterations. Default: 100
/// - `multi_class` - One of "auto", "ovr", "multinomial". Default: "auto"
/// - `random_state` - Random seed for reproducibility
/// - `l1_ratio` - Elastic-Net mixing parameter (0 to 1)
///
/// # References
///
/// - L-BFGS-B: <http://users.iems.northwestern.edu/~nocedal/lbfgsb.html>
/// - SAG: <https://hal.inria.fr/hal-00860051/document>
/// - SAGA: <https://arxiv.org/abs/1407.0202>
#[pyclass(name = "LogisticRegression")]
pub struct PyLogisticRegression {
    // Hyperparameters captured at construction / via set_params.
    py_config: PyLogisticRegressionConfig,
    // `None` until `fit()` succeeds; reset to `None` by `set_params`.
    fitted_model: Option<LogisticRegression<Trained>>,
}
113
114#[pymethods]
115impl PyLogisticRegression {
116    #[new]
117    #[allow(clippy::too_many_arguments)]
118    #[pyo3(signature = (penalty="l2", dual=false, tol=1e-4, c=1.0, fit_intercept=true, intercept_scaling=1.0, class_weight=None, random_state=None, solver="lbfgs", max_iter=100, multi_class="auto", verbose=0, warm_start=false, n_jobs=None, l1_ratio=None))]
119    fn new(
120        penalty: &str,
121        dual: bool,
122        tol: f64,
123        c: f64,
124        fit_intercept: bool,
125        intercept_scaling: f64,
126        class_weight: Option<&str>,
127        random_state: Option<i32>,
128        solver: &str,
129        max_iter: usize,
130        multi_class: &str,
131        verbose: i32,
132        warm_start: bool,
133        n_jobs: Option<i32>,
134        l1_ratio: Option<f64>,
135    ) -> Self {
136        // Note: Some parameters are sklearn-specific and don't directly map to our implementation
137        let _dual = dual;
138        let _intercept_scaling = intercept_scaling;
139        let _verbose = verbose;
140
141        let py_config = PyLogisticRegressionConfig {
142            penalty: penalty.to_string(),
143            c,
144            fit_intercept,
145            max_iter,
146            tol,
147            solver: solver.to_string(),
148            random_state,
149            class_weight: class_weight.map(|s| s.to_string()),
150            multi_class: multi_class.to_string(),
151            warm_start,
152            n_jobs,
153            l1_ratio,
154        };
155
156        Self {
157            py_config,
158            fitted_model: None,
159        }
160    }
161
162    /// Fit the logistic regression model
163    fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
164        let x_array = pyarray_to_core_array2(x)?;
165        let y_array = pyarray_to_core_array1(y)?;
166
167        // Validate input arrays
168        validate_fit_arrays(&x_array, &y_array)?;
169
170        // Create sklears-linear model with Logistic Regression configuration
171        let model = LogisticRegression::new()
172            .max_iter(self.py_config.max_iter)
173            .fit_intercept(self.py_config.fit_intercept);
174
175        // Apply penalty if specified
176        let model = match self.py_config.penalty.as_str() {
177            "l1" => model.penalty(Penalty::L1(1.0 / self.py_config.c)),
178            "l2" => model.penalty(Penalty::L2(1.0 / self.py_config.c)),
179            "elasticnet" => model.penalty(Penalty::ElasticNet {
180                alpha: 1.0 / self.py_config.c,
181                l1_ratio: self.py_config.l1_ratio.unwrap_or(0.5),
182            }),
183            _ => model, // Default (no additional penalty)
184        };
185
186        // Apply solver if specified
187        let model = match self.py_config.solver.as_str() {
188            "lbfgs" => model.solver(Solver::Lbfgs),
189            "sag" => model.solver(Solver::Sag),
190            "saga" => model.solver(Solver::Saga),
191            "newton-cg" => model.solver(Solver::Newton),
192            _ => model.solver(Solver::Auto),
193        };
194
195        // Apply random state if specified
196        let model = if let Some(rs) = self.py_config.random_state {
197            model.random_state(rs as u64)
198        } else {
199            model
200        };
201
202        // Fit the model using sklears-linear's implementation
203        match model.fit(&x_array, &y_array) {
204            Ok(fitted_model) => {
205                self.fitted_model = Some(fitted_model);
206                Ok(())
207            }
208            Err(e) => Err(PyValueError::new_err(format!(
209                "Failed to fit Logistic Regression model: {:?}",
210                e
211            ))),
212        }
213    }
214
215    /// Predict class labels for samples
216    fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
217        let fitted = self
218            .fitted_model
219            .as_ref()
220            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
221
222        let x_array = pyarray_to_core_array2(x)?;
223        validate_predict_array(&x_array)?;
224
225        match fitted.predict(&x_array) {
226            Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
227            Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
228        }
229    }
230
231    /// Predict class probabilities for samples
232    fn predict_proba(
233        &self,
234        py: Python<'_>,
235        x: PyReadonlyArray2<f64>,
236    ) -> PyResult<Py<PyArray2<f64>>> {
237        let fitted = self
238            .fitted_model
239            .as_ref()
240            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
241
242        let x_array = pyarray_to_core_array2(x)?;
243        validate_predict_array(&x_array)?;
244
245        match fitted.predict_proba(&x_array) {
246            Ok(probabilities) => core_array2_to_py(py, &probabilities),
247            Err(e) => Err(PyValueError::new_err(format!(
248                "Probability prediction failed: {:?}",
249                e
250            ))),
251        }
252    }
253
254    /// Get model coefficients
255    #[getter]
256    fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
257        let fitted = self
258            .fitted_model
259            .as_ref()
260            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
261
262        Ok(core_array1_to_py(py, fitted.coef()))
263    }
264
265    /// Get model intercept
266    #[getter]
267    fn intercept_(&self) -> PyResult<f64> {
268        let fitted = self
269            .fitted_model
270            .as_ref()
271            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
272
273        Ok(fitted.intercept().unwrap_or(0.0))
274    }
275
276    /// Get unique class labels
277    #[getter]
278    fn classes_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
279        let fitted = self
280            .fitted_model
281            .as_ref()
282            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
283
284        Ok(core_array1_to_py(py, fitted.classes()))
285    }
286
287    /// Calculate accuracy score
288    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
289        let fitted = self
290            .fitted_model
291            .as_ref()
292            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
293
294        let x_array = pyarray_to_core_array2(x)?;
295        let y_array = pyarray_to_core_array1(y)?;
296
297        match fitted.score(&x_array, &y_array) {
298            Ok(score) => Ok(score),
299            Err(e) => Err(PyValueError::new_err(format!(
300                "Score calculation failed: {:?}",
301                e
302            ))),
303        }
304    }
305
306    /// Get number of features
307    #[getter]
308    fn n_features_in_(&self) -> PyResult<usize> {
309        let fitted = self
310            .fitted_model
311            .as_ref()
312            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
313
314        // Infer number of features from coefficient array length
315        Ok(fitted.coef().len())
316    }
317
318    /// Return parameters for this estimator (sklearn compatibility)
319    fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
320        let _deep = deep.unwrap_or(true);
321
322        let dict = PyDict::new(py);
323
324        dict.set_item("penalty", &self.py_config.penalty)?;
325        dict.set_item("C", self.py_config.c)?;
326        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
327        dict.set_item("max_iter", self.py_config.max_iter)?;
328        dict.set_item("tol", self.py_config.tol)?;
329        dict.set_item("solver", &self.py_config.solver)?;
330        dict.set_item("random_state", self.py_config.random_state)?;
331        dict.set_item("class_weight", &self.py_config.class_weight)?;
332        dict.set_item("multi_class", &self.py_config.multi_class)?;
333        dict.set_item("warm_start", self.py_config.warm_start)?;
334        dict.set_item("n_jobs", self.py_config.n_jobs)?;
335        dict.set_item("l1_ratio", self.py_config.l1_ratio)?;
336
337        Ok(dict.into())
338    }
339
340    /// Set parameters for this estimator (sklearn compatibility)
341    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
342        // Update configuration parameters
343        if let Some(penalty) = kwargs.get_item("penalty")? {
344            let penalty_str: String = penalty.extract()?;
345            self.py_config.penalty = penalty_str;
346        }
347        if let Some(c) = kwargs.get_item("C")? {
348            self.py_config.c = c.extract()?;
349        }
350        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
351            self.py_config.fit_intercept = fit_intercept.extract()?;
352        }
353        if let Some(max_iter) = kwargs.get_item("max_iter")? {
354            self.py_config.max_iter = max_iter.extract()?;
355        }
356        if let Some(tol) = kwargs.get_item("tol")? {
357            self.py_config.tol = tol.extract()?;
358        }
359        if let Some(solver) = kwargs.get_item("solver")? {
360            let solver_str: String = solver.extract()?;
361            self.py_config.solver = solver_str;
362        }
363        if let Some(random_state) = kwargs.get_item("random_state")? {
364            self.py_config.random_state = random_state.extract()?;
365        }
366        if let Some(class_weight) = kwargs.get_item("class_weight")? {
367            let weight_str: Option<String> = class_weight.extract()?;
368            self.py_config.class_weight = weight_str;
369        }
370        if let Some(multi_class) = kwargs.get_item("multi_class")? {
371            let multi_class_str: String = multi_class.extract()?;
372            self.py_config.multi_class = multi_class_str;
373        }
374        if let Some(warm_start) = kwargs.get_item("warm_start")? {
375            self.py_config.warm_start = warm_start.extract()?;
376        }
377        if let Some(n_jobs) = kwargs.get_item("n_jobs")? {
378            self.py_config.n_jobs = n_jobs.extract()?;
379        }
380        if let Some(l1_ratio) = kwargs.get_item("l1_ratio")? {
381            self.py_config.l1_ratio = l1_ratio.extract()?;
382        }
383
384        // Clear fitted model since config changed
385        self.fitted_model = None;
386
387        Ok(())
388    }
389
390    /// String representation
391    fn __repr__(&self) -> String {
392        format!(
393            "LogisticRegression(penalty='{}', C={}, fit_intercept={}, max_iter={}, tol={}, solver='{}', random_state={:?}, multi_class='{}')",
394            self.py_config.penalty,
395            self.py_config.c,
396            self.py_config.fit_intercept,
397            self.py_config.max_iter,
398            self.py_config.tol,
399            self.py_config.solver,
400            self.py_config.random_state,
401            self.py_config.multi_class
402        )
403    }
404}