sklears_python/linear/logistic_regression.rs
//! Python bindings for Logistic Regression
//!
//! This module provides Python bindings for Logistic Regression,
//! offering scikit-learn compatible interfaces for binary classification
//! using the sklears-linear crate.

use super::common::*;
use pyo3::types::PyDict;
use pyo3::Bound;
use sklears_core::traits::{Fit, Predict, PredictProba, Score, Trained};
use sklears_linear::{LogisticRegression, LogisticRegressionConfig, Penalty, Solver};

/// Python-specific configuration wrapper for LogisticRegression
#[derive(Debug, Clone)]
pub struct PyLogisticRegressionConfig {
    pub penalty: String,
    pub c: f64,
    pub fit_intercept: bool,
    pub max_iter: usize,
    pub tol: f64,
    pub solver: String,
    pub random_state: Option<i32>,
    pub class_weight: Option<String>,
    pub multi_class: String,
    pub warm_start: bool,
    pub n_jobs: Option<i32>,
    pub l1_ratio: Option<f64>,
}

impl Default for PyLogisticRegressionConfig {
    fn default() -> Self {
        Self {
            penalty: "l2".to_string(),
            c: 1.0,
            fit_intercept: true,
            max_iter: 100,
            tol: 1e-4,
            solver: "lbfgs".to_string(),
            random_state: None,
            class_weight: None,
            multi_class: "auto".to_string(),
            warm_start: false,
            n_jobs: None,
            l1_ratio: None,
        }
    }
}
impl From<PyLogisticRegressionConfig> for LogisticRegressionConfig {
    fn from(py_config: PyLogisticRegressionConfig) -> Self {
        // Convert penalty string to Penalty enum
        let penalty = match py_config.penalty.as_str() {
            "l1" => Penalty::L1(1.0 / py_config.c),
            "l2" => Penalty::L2(1.0 / py_config.c),
            "elasticnet" => Penalty::ElasticNet {
                alpha: 1.0 / py_config.c,
                l1_ratio: py_config.l1_ratio.unwrap_or(0.5),
            },
            _ => Penalty::L2(1.0 / py_config.c), // Default to L2
        };
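
        // The sklearn-style C maps to penalty strength as alpha = 1 / C,
        // so smaller C means stronger regularization. A worked example:
        //   C = 1.0  -> alpha = 1.0 / 1.0  = 1.0  (default strength)
        //   C = 0.5  -> alpha = 1.0 / 0.5  = 2.0  (stronger regularization)
        //   C = 10.0 -> alpha = 1.0 / 10.0 = 0.1  (weaker regularization)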

        // Convert solver string to Solver enum
        let solver = match py_config.solver.as_str() {
            "lbfgs" => Solver::Lbfgs,
            "sag" => Solver::Sag,
            "saga" => Solver::Saga,
            "newton-cg" => Solver::Newton,
            _ => Solver::Auto, // Default to Auto
        };

        LogisticRegressionConfig {
            penalty,
            solver,
            max_iter: py_config.max_iter,
            tol: py_config.tol,
            fit_intercept: py_config.fit_intercept,
            random_state: py_config.random_state.map(|s| s as u64),
        }
    }
}

/// Logistic Regression (aka logit, MaxEnt) classifier.
///
/// In the multiclass case, the training algorithm uses the one-vs-rest (OvR)
/// scheme if the 'multi_class' option is set to 'ovr', and uses the
/// cross-entropy loss if the 'multi_class' option is set to 'multinomial'.
/// (Currently the 'multinomial' option is supported only by the 'lbfgs',
/// 'sag', 'saga' and 'newton-cg' solvers.)
///
/// This class implements regularized logistic regression using various solvers.
/// **Note that regularization is applied by default**. This binding accepts
/// dense, C-ordered NumPy arrays of 64-bit floats.
///
/// The 'newton-cg', 'sag', and 'lbfgs' solvers support only L2 regularization
/// with primal formulation, or no regularization. The Elastic-Net
/// regularization is only supported by the 'saga' solver.
///
/// Parameters
/// ----------
/// penalty : {'l1', 'l2', 'elasticnet'}, default='l2'
///     Specify the norm of the penalty:
///
///     - 'l2': add an L2 penalty term (the default choice);
///     - 'l1': add an L1 penalty term;
///     - 'elasticnet': both L1 and L2 penalty terms are added.
///
/// tol : float, default=1e-4
///     Tolerance for stopping criteria.
///
/// C : float, default=1.0
///     Inverse of regularization strength; must be a positive float.
///     Like in support vector machines, smaller values specify stronger
///     regularization. Note that this binding's constructor accepts the
///     parameter as lowercase ``c``.
///
/// fit_intercept : bool, default=True
///     Specifies if a constant (a.k.a. bias or intercept) should be
///     added to the decision function.
///
/// class_weight : dict or 'balanced', default=None
///     Weights associated with classes in the form ``{class_label: weight}``.
///     If not given, all classes are supposed to have weight one.
///
///     The "balanced" mode uses the values of y to automatically adjust
///     weights inversely proportional to class frequencies in the input data
///     as ``n_samples / (n_classes * np.bincount(y))``.
///
/// random_state : int, default=None
///     Used when ``solver`` == 'sag' or 'saga' to shuffle the
///     data. See :term:`Glossary <random_state>` for details.
///
/// solver : {'lbfgs', 'newton-cg', 'sag', 'saga'}, default='lbfgs'
///     Algorithm to use in the optimization problem.
///     To choose a solver, you might want to consider the following aspects:
///
///     - For small datasets, 'lbfgs' is a good choice, whereas 'sag'
///       and 'saga' are faster for large ones;
///     - For multiclass problems, only 'newton-cg', 'sag', 'saga' and
///       'lbfgs' handle multinomial loss.
///
/// max_iter : int, default=100
///     Maximum number of iterations taken for the solvers to converge.
///
/// multi_class : {'auto', 'ovr', 'multinomial'}, default='auto'
///     If the option chosen is 'ovr', then a binary problem is fit for each
///     label. For 'multinomial' the loss minimised is the multinomial loss fit
///     across the entire probability distribution, *even when the data is
///     binary*. 'auto' selects 'ovr' if the data is binary,
///     and otherwise selects 'multinomial'.
///
/// warm_start : bool, default=False
///     When set to True, reuse the solution of the previous call to fit as
///     initialization; otherwise, just erase the previous solution.
///     See :term:`the Glossary <warm_start>`.
///
/// n_jobs : int, default=None
///     Number of CPU cores used when parallelizing over classes if
///     ``multi_class='ovr'``. ``None`` means 1 unless in a
///     :obj:`joblib.parallel_backend` context. ``-1`` means using all
///     processors. See :term:`Glossary <n_jobs>` for more details.
///
/// l1_ratio : float, default=None
///     The Elastic-Net mixing parameter, with ``0 <= l1_ratio <= 1``. Only
///     used if ``penalty='elasticnet'``. Setting ``l1_ratio=0`` is equivalent
///     to using ``penalty='l2'``, while setting ``l1_ratio=1`` is equivalent
///     to using ``penalty='l1'``. For ``0 < l1_ratio < 1``, the penalty is a
///     combination of L1 and L2.
///
/// Attributes
/// ----------
/// classes_ : ndarray of shape (n_classes,)
///     A list of class labels known to the classifier.
///
/// coef_ : ndarray of shape (n_features,)
///     Coefficient of the features in the decision function.
///
///     This binding exposes `coef_` as a 1-D array of shape (n_features,).
///
/// intercept_ : float
///     Intercept (a.k.a. bias) added to the decision function.
///
///     This binding returns the intercept as a scalar float; it is 0.0 if
///     `fit_intercept` is set to False.
///
/// n_features_in_ : int
///     Number of features seen during :term:`fit`.
///
/// Examples
/// --------
/// >>> from sklears_python import LogisticRegression
/// >>> import numpy as np
/// >>> X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]], dtype=np.float64)
/// >>> y = np.array([0, 0, 1, 1], dtype=np.float64)
/// >>> clf = LogisticRegression(random_state=0)
/// >>> clf.fit(X, y)
/// >>> clf.predict(X[:2, :])
/// array([0., 0.])
/// >>> clf.predict_proba(X[:2, :])
/// array([[...]])
/// >>> clf.score(X, y)
/// 1.0
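///
/// A hypothetical elastic-net configuration (a sketch; assumes the 'saga'
/// solver path is available in this build):
///
/// >>> clf_en = LogisticRegression(penalty='elasticnet', solver='saga', l1_ratio=0.5)
/// >>> clf_en.get_params()['l1_ratio']
/// 0.5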
///
/// Notes
/// -----
/// The underlying implementation uses optimized solvers from sklears-linear.
///
/// References
/// ----------
/// L-BFGS-B -- Software for Large-scale Bound-constrained Optimization
///     Ciyou Zhu, Richard Byrd, Jorge Nocedal and Jose Luis Morales.
///     http://users.iems.northwestern.edu/~nocedal/lbfgsb.html
///
/// SAG -- Mark Schmidt, Nicolas Le Roux, and Francis Bach
///     Minimizing Finite Sums with the Stochastic Average Gradient
///     https://hal.inria.fr/hal-00860051/document
///
/// SAGA -- Defazio, A., Bach F. & Lacoste-Julien S. (2014).
///     SAGA: A Fast Incremental Gradient Method With Support
///     for Non-Strongly Convex Composite Objectives
///     https://arxiv.org/abs/1407.0202
#[pyclass(name = "LogisticRegression")]
pub struct PyLogisticRegression {
    py_config: PyLogisticRegressionConfig,
    fitted_model: Option<LogisticRegression<Trained>>,
}

#[pymethods]
impl PyLogisticRegression {
    #[new]
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (penalty="l2", dual=false, tol=1e-4, c=1.0, fit_intercept=true, intercept_scaling=1.0, class_weight=None, random_state=None, solver="lbfgs", max_iter=100, multi_class="auto", verbose=0, warm_start=false, n_jobs=None, l1_ratio=None))]
    fn new(
        penalty: &str,
        dual: bool,
        tol: f64,
        c: f64,
        fit_intercept: bool,
        intercept_scaling: f64,
        class_weight: Option<&str>,
        random_state: Option<i32>,
        solver: &str,
        max_iter: usize,
        multi_class: &str,
        verbose: i32,
        warm_start: bool,
        n_jobs: Option<i32>,
        l1_ratio: Option<f64>,
    ) -> Self {
        // Note: `dual`, `intercept_scaling`, and `verbose` are accepted for
        // sklearn signature compatibility but don't map to this implementation.
        // `class_weight`, `multi_class`, `warm_start`, and `n_jobs` are stored
        // in the config but are not currently forwarded to the solver.
        let _dual = dual;
        let _intercept_scaling = intercept_scaling;
        let _verbose = verbose;

        let py_config = PyLogisticRegressionConfig {
            penalty: penalty.to_string(),
            c,
            fit_intercept,
            max_iter,
            tol,
            solver: solver.to_string(),
            random_state,
            class_weight: class_weight.map(|s| s.to_string()),
            multi_class: multi_class.to_string(),
            warm_start,
            n_jobs,
            l1_ratio,
        };

        Self {
            py_config,
            fitted_model: None,
        }
    }

    /// Fit the logistic regression model
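    ///
    /// A minimal usage sketch (arrays passed as float64, matching the f64
    /// signatures below; unlike scikit-learn, this `fit` returns None rather
    /// than self):
    ///
    /// >>> import numpy as np
    /// >>> clf = LogisticRegression()
    /// >>> clf.fit(np.array([[0., 0.], [1., 1.]]), np.array([0., 1.]))
    /// >>> clf.n_features_in_
    /// 2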
    fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
        let x_array = pyarray_to_core_array2(x)?;
        let y_array = pyarray_to_core_array1(y)?;

        // Validate input arrays
        validate_fit_arrays(&x_array, &y_array)?;

        // Create sklears-linear model with Logistic Regression configuration
        let model = LogisticRegression::new()
            .max_iter(self.py_config.max_iter)
            .fit_intercept(self.py_config.fit_intercept);

        // Apply penalty if specified (alpha = 1 / C, as in scikit-learn)
        let model = match self.py_config.penalty.as_str() {
            "l1" => model.penalty(Penalty::L1(1.0 / self.py_config.c)),
            "l2" => model.penalty(Penalty::L2(1.0 / self.py_config.c)),
            "elasticnet" => model.penalty(Penalty::ElasticNet {
                alpha: 1.0 / self.py_config.c,
                l1_ratio: self.py_config.l1_ratio.unwrap_or(0.5),
            }),
            _ => model, // Unrecognized penalty: keep the model's default
        };

        // Apply solver if specified
        let model = match self.py_config.solver.as_str() {
            "lbfgs" => model.solver(Solver::Lbfgs),
            "sag" => model.solver(Solver::Sag),
            "saga" => model.solver(Solver::Saga),
            "newton-cg" => model.solver(Solver::Newton),
            _ => model.solver(Solver::Auto),
        };

        // Apply random state if specified
        let model = if let Some(rs) = self.py_config.random_state {
            model.random_state(rs as u64)
        } else {
            model
        };

        // Fit the model using sklears-linear's implementation
        match model.fit(&x_array, &y_array) {
            Ok(fitted_model) => {
                self.fitted_model = Some(fitted_model);
                Ok(())
            }
            Err(e) => Err(PyValueError::new_err(format!(
                "Failed to fit Logistic Regression model: {:?}",
                e
            ))),
        }
    }

    /// Predict class labels for samples
    fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = pyarray_to_core_array2(x)?;
        validate_predict_array(&x_array)?;

        match fitted.predict(&x_array) {
            Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
            Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
        }
    }

    /// Predict class probabilities for samples
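    ///
    /// A usage sketch (shape check only, since exact probabilities are
    /// model-dependent; this assumes columns follow `classes_` order, as in
    /// scikit-learn):
    ///
    /// >>> proba = clf.predict_proba(X)
    /// >>> proba.shape == (X.shape[0], len(clf.classes_))
    /// True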
    fn predict_proba(
        &self,
        py: Python<'_>,
        x: PyReadonlyArray2<f64>,
    ) -> PyResult<Py<PyArray2<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = pyarray_to_core_array2(x)?;
        validate_predict_array(&x_array)?;

        match fitted.predict_proba(&x_array) {
            Ok(probabilities) => core_array2_to_py(py, &probabilities),
            Err(e) => Err(PyValueError::new_err(format!(
                "Probability prediction failed: {:?}",
                e
            ))),
        }
    }

    /// Get model coefficients as a 1-D array of shape (n_features,)
    #[getter]
    fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        Ok(core_array1_to_py(py, fitted.coef()))
    }

    /// Get model intercept (0.0 if the model was fitted without an intercept)
    #[getter]
    fn intercept_(&self) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        Ok(fitted.intercept().unwrap_or(0.0))
    }

    /// Get unique class labels
    #[getter]
    fn classes_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        Ok(core_array1_to_py(py, fitted.classes()))
    }

    /// Calculate accuracy score
    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = pyarray_to_core_array2(x)?;
        let y_array = pyarray_to_core_array1(y)?;

        match fitted.score(&x_array, &y_array) {
            Ok(score) => Ok(score),
            Err(e) => Err(PyValueError::new_err(format!(
                "Score calculation failed: {:?}",
                e
            ))),
        }
    }

    /// Get number of features
    #[getter]
    fn n_features_in_(&self) -> PyResult<usize> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        // Infer number of features from coefficient array length
        Ok(fitted.coef().len())
    }

    /// Return parameters for this estimator (sklearn compatibility)
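    ///
    /// Note that the constructor takes the regularization parameter as
    /// lowercase ``c``, while ``get_params`` exposes it under the
    /// scikit-learn key ``'C'``:
    ///
    /// >>> LogisticRegression(c=2.0).get_params()['C']
    /// 2.0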
    fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
        let _deep = deep.unwrap_or(true);

        let dict = PyDict::new(py);

        dict.set_item("penalty", &self.py_config.penalty)?;
        dict.set_item("C", self.py_config.c)?;
        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
        dict.set_item("max_iter", self.py_config.max_iter)?;
        dict.set_item("tol", self.py_config.tol)?;
        dict.set_item("solver", &self.py_config.solver)?;
        dict.set_item("random_state", self.py_config.random_state)?;
        dict.set_item("class_weight", &self.py_config.class_weight)?;
        dict.set_item("multi_class", &self.py_config.multi_class)?;
        dict.set_item("warm_start", self.py_config.warm_start)?;
        dict.set_item("n_jobs", self.py_config.n_jobs)?;
        dict.set_item("l1_ratio", self.py_config.l1_ratio)?;

        Ok(dict.into())
    }

    /// Set parameters for this estimator (sklearn compatibility)
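    ///
    /// Unlike scikit-learn's ``set_params(**params)``, this binding takes a
    /// single dict argument. Setting parameters clears any fitted model:
    ///
    /// >>> clf.set_params({'C': 0.5, 'solver': 'saga'})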
    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
        // Update configuration parameters
        if let Some(penalty) = kwargs.get_item("penalty")? {
            let penalty_str: String = penalty.extract()?;
            self.py_config.penalty = penalty_str;
        }
        if let Some(c) = kwargs.get_item("C")? {
            self.py_config.c = c.extract()?;
        }
        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
            self.py_config.fit_intercept = fit_intercept.extract()?;
        }
        if let Some(max_iter) = kwargs.get_item("max_iter")? {
            self.py_config.max_iter = max_iter.extract()?;
        }
        if let Some(tol) = kwargs.get_item("tol")? {
            self.py_config.tol = tol.extract()?;
        }
        if let Some(solver) = kwargs.get_item("solver")? {
            let solver_str: String = solver.extract()?;
            self.py_config.solver = solver_str;
        }
        if let Some(random_state) = kwargs.get_item("random_state")? {
            self.py_config.random_state = random_state.extract()?;
        }
        if let Some(class_weight) = kwargs.get_item("class_weight")? {
            let weight_str: Option<String> = class_weight.extract()?;
            self.py_config.class_weight = weight_str;
        }
        if let Some(multi_class) = kwargs.get_item("multi_class")? {
            let multi_class_str: String = multi_class.extract()?;
            self.py_config.multi_class = multi_class_str;
        }
        if let Some(warm_start) = kwargs.get_item("warm_start")? {
            self.py_config.warm_start = warm_start.extract()?;
        }
        if let Some(n_jobs) = kwargs.get_item("n_jobs")? {
            self.py_config.n_jobs = n_jobs.extract()?;
        }
        if let Some(l1_ratio) = kwargs.get_item("l1_ratio")? {
            self.py_config.l1_ratio = l1_ratio.extract()?;
        }

        // Clear fitted model since config changed
        self.fitted_model = None;

        Ok(())
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!(
            "LogisticRegression(penalty='{}', C={}, fit_intercept={}, max_iter={}, tol={}, solver='{}', random_state={:?}, multi_class='{}')",
            self.py_config.penalty,
            self.py_config.c,
            self.py_config.fit_intercept,
            self.py_config.max_iter,
            self.py_config.tol,
            self.py_config.solver,
            self.py_config.random_state,
            self.py_config.multi_class
        )
    }
}