sklears_python/linear/logistic_regression.rs
1//! Python bindings for Logistic Regression
2//!
3//! This module provides Python bindings for Logistic Regression,
4//! offering scikit-learn compatible interfaces for binary and multiclass classification.
5//!
6//! Note: This is a basic implementation using manual logistic regression until
7//! the sklears-linear LogisticRegression feature compilation issues are resolved.
8
9use super::common::*;
10use numpy::IntoPyArray;
11use pyo3::types::PyDict;
12use pyo3::Bound;
13use scirs2_autograd::ndarray::{s, Array1, Array2, Axis};
14use scirs2_core::random::{thread_rng, Rng};
15
/// Hyper-parameters captured by the Python `LogisticRegression` constructor.
///
/// Field names follow scikit-learn, except that `C` is spelled `c`. Only
/// `penalty`, `c`, `fit_intercept`, `max_iter` and `tol` are consulted while
/// fitting; the remaining fields are stored so `get_params` can report them.
#[derive(Debug, Clone)]
pub struct PyLogisticRegressionConfig {
    /// Penalty name; only `"l2"` adds a penalty during fitting.
    pub penalty: String,
    /// Inverse regularization strength.
    pub c: f64,
    /// Whether an intercept term is learned.
    pub fit_intercept: bool,
    /// Upper bound on solver iterations.
    pub max_iter: usize,
    /// Stopping tolerance of the solver.
    pub tol: f64,
    /// Requested solver name (stored, not consulted during fitting).
    pub solver: String,
    /// Requested seed (stored, not consulted during fitting).
    pub random_state: Option<i32>,
    /// Requested class weighting (stored, not consulted during fitting).
    pub class_weight: Option<String>,
    /// Requested multiclass strategy (stored, not consulted during fitting).
    pub multi_class: String,
    /// Requested warm-start flag (stored, not consulted during fitting).
    pub warm_start: bool,
    /// Requested parallelism (stored, not consulted during fitting).
    pub n_jobs: Option<i32>,
    /// Requested elastic-net mixing ratio (stored, not consulted during fitting).
    pub l1_ratio: Option<f64>,
}

impl Default for PyLogisticRegressionConfig {
    /// The same defaults as the Python constructor: `penalty="l2"`, `C=1.0`,
    /// `fit_intercept=True`, `max_iter=100`, `tol=1e-4`, `solver="lbfgs"`,
    /// `multi_class="auto"`, and everything optional unset.
    fn default() -> Self {
        Self {
            penalty: "l2".into(),
            c: 1.0,
            fit_intercept: true,
            max_iter: 100,
            tol: 1e-4,
            solver: "lbfgs".into(),
            random_state: None,
            class_weight: None,
            multi_class: "auto".into(),
            warm_start: false,
            n_jobs: None,
            l1_ratio: None,
        }
    }
}
51
52/// Basic logistic regression implementation
53#[derive(Debug, Clone)]
54struct BasicLogisticRegression {
55 config: PyLogisticRegressionConfig,
56 coef_: Array1<f64>,
57 intercept_: f64,
58 classes_: Array1<f64>,
59 n_features_: usize,
60}
61
62impl BasicLogisticRegression {
63 fn new(config: PyLogisticRegressionConfig) -> Self {
64 Self {
65 config,
66 coef_: Array1::zeros(1),
67 intercept_: 0.0,
68 classes_: Array1::zeros(1),
69 n_features_: 0,
70 }
71 }
72
73 fn sigmoid(z: f64) -> f64 {
74 1.0 / (1.0 + (-z).exp())
75 }
76
77 fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), SklearsPythonError> {
78 let n_samples = x.nrows();
79 let n_features = x.ncols();
80 self.n_features_ = n_features;
81
82 if n_samples != y.len() {
83 return Err(SklearsPythonError::ValidationError(
84 "X and y have incompatible shapes".to_string(),
85 ));
86 }
87
88 // Find unique classes
89 let mut classes: Vec<f64> = y.iter().cloned().collect();
90 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
91 classes.dedup();
92 self.classes_ = Array1::from_vec(classes.clone());
93
94 // For now, only support binary classification
95 if classes.len() != 2 {
96 return Err(SklearsPythonError::ValidationError(
97 "Currently only binary classification is supported".to_string(),
98 ));
99 }
100
101 // Map classes to 0 and 1
102 let y_mapped: Array1<f64> = y.mapv(|val| if val == classes[0] { 0.0 } else { 1.0 });
103
104 // Add intercept column if needed
105 let x_design = if self.config.fit_intercept {
106 let mut x_new = Array2::ones((n_samples, n_features + 1));
107 x_new.slice_mut(s![.., 1..]).assign(x);
108 x_new
109 } else {
110 x.clone()
111 };
112
113 let n_params = if self.config.fit_intercept {
114 n_features + 1
115 } else {
116 n_features
117 };
118
119 // Initialize weights
120 let mut rng = thread_rng();
121 let mut weights = Array1::from_shape_fn(n_params, |_| rng.gen::<f64>() * 0.01);
122
123 // Gradient descent
124 let learning_rate = 0.01;
125 for _iter in 0..self.config.max_iter {
126 let mut total_loss = 0.0;
127 let mut gradient: Array1<f64> = Array1::zeros(n_params);
128
129 for i in 0..n_samples {
130 let xi = x_design.row(i);
131 let yi = y_mapped[i];
132
133 let z = xi.dot(&weights);
134 let prediction = Self::sigmoid(z);
135
136 // Log loss contribution
137 let loss = if yi == 1.0 {
138 -prediction.ln()
139 } else {
140 -(1.0 - prediction).ln()
141 };
142 total_loss += loss;
143
144 // Gradient contribution
145 let error = prediction - yi;
146 for j in 0..n_params {
147 gradient[j] += error * xi[j];
148 }
149 }
150
151 // Apply L2 regularization if configured
152 if self.config.penalty == "l2" && self.config.c > 0.0 {
153 let reg_strength = 1.0 / self.config.c;
154 for j in 0..n_params {
155 // Don't regularize intercept
156 if !self.config.fit_intercept || j > 0 {
157 gradient[j] += reg_strength * weights[j];
158 total_loss += 0.5 * reg_strength * weights[j] * weights[j];
159 }
160 }
161 }
162
163 // Update weights
164 for j in 0..n_params {
165 weights[j] -= learning_rate * gradient[j] / n_samples as f64;
166 }
167
168 // Check convergence
169 let avg_loss = total_loss / n_samples as f64;
170 if avg_loss < self.config.tol {
171 break;
172 }
173 }
174
175 // Extract coefficients and intercept
176 if self.config.fit_intercept {
177 self.intercept_ = weights[0];
178 self.coef_ = weights.slice(s![1..]).to_owned();
179 } else {
180 self.intercept_ = 0.0;
181 self.coef_ = weights;
182 }
183
184 Ok(())
185 }
186
187 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsPythonError> {
188 if x.ncols() != self.n_features_ {
189 return Err(SklearsPythonError::ValidationError(format!(
190 "X has {} features, but model expects {} features",
191 x.ncols(),
192 self.n_features_
193 )));
194 }
195
196 let probabilities = self.predict_proba(x)?;
197 let predictions = probabilities
198 .axis_iter(Axis(0))
199 .map(|row| {
200 if row[1] > 0.5 {
201 self.classes_[1]
202 } else {
203 self.classes_[0]
204 }
205 })
206 .collect::<Vec<f64>>();
207
208 Ok(Array1::from_vec(predictions))
209 }
210
211 fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsPythonError> {
212 if x.ncols() != self.n_features_ {
213 return Err(SklearsPythonError::ValidationError(format!(
214 "X has {} features, but model expects {} features",
215 x.ncols(),
216 self.n_features_
217 )));
218 }
219
220 let n_samples = x.nrows();
221 let mut probabilities = Array2::zeros((n_samples, 2));
222
223 for i in 0..n_samples {
224 let xi = x.row(i);
225 let z = xi.dot(&self.coef_) + self.intercept_;
226 let prob_class_1 = Self::sigmoid(z);
227 let prob_class_0 = 1.0 - prob_class_1;
228
229 probabilities[[i, 0]] = prob_class_0;
230 probabilities[[i, 1]] = prob_class_1;
231 }
232
233 Ok(probabilities)
234 }
235
236 fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<f64, SklearsPythonError> {
237 let predictions = self.predict(x)?;
238 let correct = y
239 .iter()
240 .zip(predictions.iter())
241 .filter(|(&true_val, &pred_val)| (true_val - pred_val).abs() < 1e-6)
242 .count();
243
244 Ok(correct as f64 / y.len() as f64)
245 }
246}
247
248/// Logistic Regression (aka logit, MaxEnt) classifier.
249///
250/// In the multiclass case, the training algorithm uses the one-vs-rest (OvR)
251/// scheme if the 'multi_class' option is set to 'ovr', and uses the
252/// cross-entropy loss if the 'multi_class' option is set to 'multinomial'.
253/// (Currently the 'multinomial' option is supported only by the 'lbfgs',
254/// 'sag', 'saga' and 'newton-cg' solvers.)
255///
256/// This class implements regularized logistic regression using the
257/// 'liblinear' library, 'newton-cg', 'sag', 'saga' and 'lbfgs' solvers.
258/// **Note that regularization is applied by default**. It can handle both
259/// dense and sparse input. Use C-ordered arrays or CSR matrices containing
260/// 64-bit floats for optimal performance; any other input format will be
261/// converted (and copied).
262///
263/// The 'newton-cg', 'sag', and 'lbfgs' solvers support only L2 regularization
264/// with primal formulation, or no regularization. The 'liblinear' solver
265/// supports both L1 and L2 regularization, with a dual formulation only for
266/// the L2 penalty. The Elastic-Net regularization is only supported by the
267/// 'saga' solver.
268///
269/// Parameters
270/// ----------
271/// penalty : {'l1', 'l2', 'elasticnet', None}, default='l2'
272/// Specify the norm of the penalty:
273///
274/// - None: no penalty is added;
275/// - 'l2': add a L2 penalty term and it is the default choice;
276/// - 'l1': add a L1 penalty term;
277/// - 'elasticnet': both L1 and L2 penalty terms are added.
278///
279/// dual : bool, default=False
280/// Dual or primal formulation. Dual formulation is only implemented for
281/// l2 penalty with liblinear solver. Prefer dual=False when
282/// n_samples > n_features.
283///
284/// tol : float, default=1e-4
285/// Tolerance for stopping criteria.
286///
287/// C : float, default=1.0
288/// Inverse of regularization strength; must be a positive float.
289/// Like in support vector machines, smaller values specify stronger
290/// regularization.
291///
292/// fit_intercept : bool, default=True
293/// Specifies if a constant (a.k.a. bias or intercept) should be
294/// added to the decision function.
295///
296/// intercept_scaling : float, default=1
297/// Useful only when the solver 'liblinear' is used
298/// and self.fit_intercept is set to True. In this case, x becomes
299/// [x, self.intercept_scaling],
300/// i.e. a "synthetic" feature with constant value equal to
301/// intercept_scaling is appended to the instance vector.
302/// The intercept becomes intercept_scaling * synthetic_feature_weight.
303///
304/// Note! the synthetic feature weight is subject to l1/l2 regularization
305/// as all other features.
306/// To lessen the effect of regularization on synthetic feature weight
307/// (and therefore on the intercept) intercept_scaling has to be increased.
308///
309/// class_weight : dict or 'balanced', default=None
310/// Weights associated with classes in the form ``{class_label: weight}``.
311/// If not given, all classes are supposed to have weight one.
312///
313/// The "balanced" mode uses the values of y to automatically adjust
314/// weights inversely proportional to class frequencies in the input data
315/// as ``n_samples / (n_classes * np.bincount(y))``.
316///
317/// Note that these weights will be multiplied with sample_weight (passed
318/// through the fit method) if sample_weight is specified.
319///
320/// random_state : int, RandomState instance, default=None
321/// Used when ``solver`` == 'sag', 'saga' or 'liblinear' to shuffle the
322/// data. See :term:`Glossary <random_state>` for details.
323///
324/// solver : {'lbfgs', 'liblinear', 'newton-cg', 'newton-cholesky', 'sag', 'saga'}, \
325/// default='lbfgs'
326///
327/// Algorithm to use in the optimization problem. Default is 'lbfgs'.
328/// To choose a solver, you might want to consider the following aspects:
329///
330/// - For small datasets, 'liblinear' is a good choice, whereas 'sag'
331/// and 'saga' are faster for large ones;
332/// - For multiclass problems, only 'newton-cg', 'sag', 'saga' and
333/// 'lbfgs' handle multinomial loss;
334/// - 'liblinear' is limited to one-versus-rest schemes.
335///
336/// max_iter : int, default=100
337/// Maximum number of iterations taken for the solvers to converge.
338///
339/// multi_class : {'auto', 'ovr', 'multinomial'}, default='auto'
340/// If the option chosen is 'ovr', then a binary problem is fit for each
341/// label. For 'multinomial' the loss minimised is the multinomial loss fit
342/// across the entire probability distribution, *even when the data is
343/// binary*. 'multinomial' is unavailable when solver='liblinear'.
344/// 'auto' selects 'ovr' if the data is binary, or if solver='liblinear',
345/// and otherwise selects 'multinomial'.
346///
347/// verbose : int, default=0
348/// For the liblinear and lbfgs solvers set verbose to any positive
349/// number for verbosity.
350///
351/// warm_start : bool, default=False
352/// When set to True, reuse the solution of the previous call to fit as
353/// initialization, otherwise, just erase the previous solution.
354/// Useless for liblinear solver. See :term:`the Glossary <warm_start>`.
355///
356/// n_jobs : int, default=None
357/// Number of CPU cores used when parallelizing over classes if
358/// multi_class='ovr'". This parameter is ignored when the ``solver``
359/// is set to 'liblinear' regardless of whether 'multi_class' is specified or
360/// not. ``None`` means 1 unless in a
361/// :obj:`joblib.parallel_backend` context. ``-1`` means using all
362/// processors. See :term:`Glossary <n_jobs>` for more details.
363///
364/// l1_ratio : float, default=None
365/// The Elastic-Net mixing parameter, with ``0 <= l1_ratio <= 1``. Only
366/// used if ``penalty='elasticnet'``. Setting ``l1_ratio=0`` is equivalent
367/// to using ``penalty='l2'``, while setting ``l1_ratio=1`` is equivalent
368/// to using ``penalty='l1'``. For ``0 < l1_ratio <1``, the penalty is a
369/// combination of L1 and L2.
370///
371/// Attributes
372/// ----------
373/// classes_ : ndarray of shape (n_classes, )
374/// A list of class labels known to the classifier.
375///
376/// coef_ : ndarray of shape (1, n_features) or (n_classes, n_features)
377/// Coefficient of the features in the decision function.
378///
379/// `coef_` is of shape (1, n_features) when the given problem is binary.
380/// In particular, when `multi_class='multinomial'`, `coef_` corresponds
381/// to outcome 1 (True) and `-coef_` corresponds to outcome 0 (False).
382///
383/// intercept_ : ndarray of shape (1,) or (n_classes,)
384/// Intercept (a.k.a. bias) added to the decision function.
385///
386/// If `fit_intercept` is set to False, the intercept is set to zero.
387/// `intercept_` is of shape (1,) when the given problem is binary.
388/// In particular, when `multi_class='multinomial'`, `intercept_`
389/// corresponds to outcome 1 (True) and `-intercept_` corresponds to
390/// outcome 0 (False).
391///
392/// n_features_in_ : int
393/// Number of features seen during :term:`fit`.
394///
395/// n_iter_ : ndarray of shape (n_classes,) or (1, )
396/// Actual number of iterations for all classes. If binary or multinomial,
397/// it returns only 1 element. For liblinear solver, only the maximum
398/// number of iteration across all classes is given.
399///
400/// Examples
401/// --------
402/// >>> from sklears_python import LogisticRegression
403/// >>> from sklearn.datasets import load_iris
404/// >>> X, y = load_iris(return_X_y=True)
405/// >>> clf = LogisticRegression(random_state=0).fit(X, y)
406/// >>> clf.predict(X[:2, :])
407/// array([0, 0])
408/// >>> clf.predict_proba(X[:2, :])
409/// array([[9.8...e-01, 1.8...e-02, 1.4...e-08],
410/// [9.7...e-01, 2.8...e-02, ...e-08]])
411/// >>> clf.score(X, y)
412/// 0.97...
413///
414/// Notes
415/// -----
416/// The underlying C implementation uses a random number generator to
417/// select features when fitting the model. It is thus not uncommon,
418/// to have slightly different results for the same input data. If
419/// that happens, try with a smaller tol parameter.
420///
421/// Predict output may not match that of standalone liblinear in certain
422/// cases. See :ref:`differences from liblinear <liblinear_differences>`
423/// in the narrative documentation.
424///
425/// References
426/// ----------
427/// L-BFGS-B -- Software for Large-scale Bound-constrained Optimization
428/// Ciyou Zhu, Richard Byrd, Jorge Nocedal and Jose Luis Morales.
429/// http://users.iems.northwestern.edu/~nocedal/lbfgsb.html
430///
431/// LIBLINEAR -- A Library for Large Linear Classification
432/// https://www.csie.ntu.edu.tw/~cjlin/liblinear/
433///
434/// SAG -- Mark Schmidt, Nicolas Le Roux, and Francis Bach
435/// Minimizing Finite Sums with the Stochastic Average Gradient
436/// https://hal.inria.fr/hal-00860051/document
437///
438/// SAGA -- Defazio, A., Bach F. & Lacoste-Julien S. (2014).
439/// SAGA: A Fast Incremental Gradient Method With Support
440/// for Non-Strongly Convex Composite Objectives
441/// https://arxiv.org/abs/1407.0202
442///
443/// Hsiang-Fu Yu, Fang-Lan Huang, Chih-Jen Lin (2011). Dual coordinate descent
444/// methods for logistic regression and maximum entropy models.
445/// Machine Learning 85(1-2):41-75.
446/// https://www.csie.ntu.edu.tw/~cjlin/papers/maxent_dual.pdf
447#[pyclass(name = "LogisticRegression")]
448pub struct PyLogisticRegression {
449 py_config: PyLogisticRegressionConfig,
450 fitted_model: Option<BasicLogisticRegression>,
451}
452
453#[pymethods]
454impl PyLogisticRegression {
455 #[new]
456 #[pyo3(signature = (penalty="l2", dual=false, tol=1e-4, c=1.0, fit_intercept=true, intercept_scaling=1.0, class_weight=None, random_state=None, solver="lbfgs", max_iter=100, multi_class="auto", verbose=0, warm_start=false, n_jobs=None, l1_ratio=None))]
457 fn new(
458 penalty: &str,
459 dual: bool,
460 tol: f64,
461 c: f64,
462 fit_intercept: bool,
463 intercept_scaling: f64,
464 class_weight: Option<&str>,
465 random_state: Option<i32>,
466 solver: &str,
467 max_iter: usize,
468 multi_class: &str,
469 verbose: i32,
470 warm_start: bool,
471 n_jobs: Option<i32>,
472 l1_ratio: Option<f64>,
473 ) -> Self {
474 // Note: Some parameters are sklearn-specific and don't directly map to our implementation
475 let _dual = dual;
476 let _intercept_scaling = intercept_scaling;
477 let _verbose = verbose;
478
479 let py_config = PyLogisticRegressionConfig {
480 penalty: penalty.to_string(),
481 c,
482 fit_intercept,
483 max_iter,
484 tol,
485 solver: solver.to_string(),
486 random_state,
487 class_weight: class_weight.map(|s| s.to_string()),
488 multi_class: multi_class.to_string(),
489 warm_start,
490 n_jobs,
491 l1_ratio,
492 };
493
494 Self {
495 py_config,
496 fitted_model: None,
497 }
498 }
499
500 /// Fit the logistic regression model
501 fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
502 let x_array = x.as_array().to_owned();
503 let y_array = y.as_array().to_owned();
504
505 // Validate input arrays using enhanced validation
506 validate_fit_arrays_enhanced(&x_array, &y_array).map_err(PyErr::from)?;
507
508 // Create and fit model
509 let mut model = BasicLogisticRegression::new(self.py_config.clone());
510 match model.fit(&x_array, &y_array) {
511 Ok(()) => {
512 self.fitted_model = Some(model);
513 Ok(())
514 }
515 Err(e) => Err(PyErr::from(e)),
516 }
517 }
518
519 /// Predict class labels for samples
520 fn predict(&self, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
521 let fitted = self
522 .fitted_model
523 .as_ref()
524 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
525
526 let x_array = x.as_array().to_owned();
527 validate_predict_array(&x_array)?;
528
529 match fitted.predict(&x_array) {
530 Ok(predictions) => {
531 let py = unsafe { Python::assume_attached() };
532 Ok(predictions.into_pyarray(py).into())
533 }
534 Err(e) => Err(PyErr::from(e)),
535 }
536 }
537
538 /// Predict class probabilities for samples
539 fn predict_proba(&self, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray2<f64>>> {
540 let fitted = self
541 .fitted_model
542 .as_ref()
543 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
544
545 let x_array = x.as_array().to_owned();
546 validate_predict_array(&x_array)?;
547
548 match fitted.predict_proba(&x_array) {
549 Ok(probabilities) => {
550 let py = unsafe { Python::assume_attached() };
551 Ok(probabilities.into_pyarray(py).into())
552 }
553 Err(e) => Err(PyErr::from(e)),
554 }
555 }
556
557 /// Get model coefficients
558 #[getter]
559 fn coef_(&self) -> PyResult<Py<PyArray1<f64>>> {
560 let fitted = self
561 .fitted_model
562 .as_ref()
563 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
564
565 let py = unsafe { Python::assume_attached() };
566 Ok(fitted.coef_.clone().into_pyarray(py).into())
567 }
568
569 /// Get model intercept
570 #[getter]
571 fn intercept_(&self) -> PyResult<f64> {
572 let fitted = self
573 .fitted_model
574 .as_ref()
575 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
576
577 Ok(fitted.intercept_)
578 }
579
580 /// Get unique class labels
581 #[getter]
582 fn classes_(&self) -> PyResult<Py<PyArray1<f64>>> {
583 let fitted = self
584 .fitted_model
585 .as_ref()
586 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
587
588 let py = unsafe { Python::assume_attached() };
589 Ok(fitted.classes_.clone().into_pyarray(py).into())
590 }
591
592 /// Calculate accuracy score
593 fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
594 let fitted = self
595 .fitted_model
596 .as_ref()
597 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
598
599 let x_array = x.as_array().to_owned();
600 let y_array = y.as_array().to_owned();
601
602 match fitted.score(&x_array, &y_array) {
603 Ok(score) => Ok(score),
604 Err(e) => Err(PyErr::from(e)),
605 }
606 }
607
608 /// Get number of features
609 #[getter]
610 fn n_features_in_(&self) -> PyResult<usize> {
611 let fitted = self
612 .fitted_model
613 .as_ref()
614 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
615
616 Ok(fitted.n_features_)
617 }
618
619 /// Return parameters for this estimator (sklearn compatibility)
620 fn get_params(&self, deep: Option<bool>) -> PyResult<Py<PyDict>> {
621 let _deep = deep.unwrap_or(true);
622
623 let py = unsafe { Python::assume_attached() };
624 let dict = PyDict::new(py);
625
626 dict.set_item("penalty", &self.py_config.penalty)?;
627 dict.set_item("C", self.py_config.c)?;
628 dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
629 dict.set_item("max_iter", self.py_config.max_iter)?;
630 dict.set_item("tol", self.py_config.tol)?;
631 dict.set_item("solver", &self.py_config.solver)?;
632 dict.set_item("random_state", self.py_config.random_state)?;
633 dict.set_item("class_weight", &self.py_config.class_weight)?;
634 dict.set_item("multi_class", &self.py_config.multi_class)?;
635 dict.set_item("warm_start", self.py_config.warm_start)?;
636 dict.set_item("n_jobs", self.py_config.n_jobs)?;
637 dict.set_item("l1_ratio", self.py_config.l1_ratio)?;
638
639 Ok(dict.into())
640 }
641
642 /// Set parameters for this estimator (sklearn compatibility)
643 fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
644 // Update configuration parameters
645 if let Some(penalty) = kwargs.get_item("penalty")? {
646 let penalty_str: String = penalty.extract()?;
647 self.py_config.penalty = penalty_str;
648 }
649 if let Some(c) = kwargs.get_item("C")? {
650 self.py_config.c = c.extract()?;
651 }
652 if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
653 self.py_config.fit_intercept = fit_intercept.extract()?;
654 }
655 if let Some(max_iter) = kwargs.get_item("max_iter")? {
656 self.py_config.max_iter = max_iter.extract()?;
657 }
658 if let Some(tol) = kwargs.get_item("tol")? {
659 self.py_config.tol = tol.extract()?;
660 }
661 if let Some(solver) = kwargs.get_item("solver")? {
662 let solver_str: String = solver.extract()?;
663 self.py_config.solver = solver_str;
664 }
665 if let Some(random_state) = kwargs.get_item("random_state")? {
666 self.py_config.random_state = random_state.extract()?;
667 }
668 if let Some(class_weight) = kwargs.get_item("class_weight")? {
669 let weight_str: Option<String> = class_weight.extract()?;
670 self.py_config.class_weight = weight_str;
671 }
672 if let Some(multi_class) = kwargs.get_item("multi_class")? {
673 let multi_class_str: String = multi_class.extract()?;
674 self.py_config.multi_class = multi_class_str;
675 }
676 if let Some(warm_start) = kwargs.get_item("warm_start")? {
677 self.py_config.warm_start = warm_start.extract()?;
678 }
679 if let Some(n_jobs) = kwargs.get_item("n_jobs")? {
680 self.py_config.n_jobs = n_jobs.extract()?;
681 }
682 if let Some(l1_ratio) = kwargs.get_item("l1_ratio")? {
683 self.py_config.l1_ratio = l1_ratio.extract()?;
684 }
685
686 // Clear fitted model since config changed
687 self.fitted_model = None;
688
689 Ok(())
690 }
691
692 /// String representation
693 fn __repr__(&self) -> String {
694 format!(
695 "LogisticRegression(penalty='{}', C={}, fit_intercept={}, max_iter={}, tol={}, solver='{}', random_state={:?}, multi_class='{}')",
696 self.py_config.penalty,
697 self.py_config.c,
698 self.py_config.fit_intercept,
699 self.py_config.max_iter,
700 self.py_config.tol,
701 self.py_config.solver,
702 self.py_config.random_state,
703 self.py_config.multi_class
704 )
705 }
706}