sklears_python/linear/logistic_regression.rs

//! Python bindings for Logistic Regression
//!
//! This module provides Python bindings for Logistic Regression,
//! offering scikit-learn compatible interfaces for binary and multiclass classification.
//!
//! Note: This is a basic implementation using manual logistic regression until
//! the sklears-linear LogisticRegression feature compilation issues are resolved.

use super::common::*;
use numpy::IntoPyArray;
use pyo3::types::PyDict;
use pyo3::Bound;
use scirs2_autograd::ndarray::{s, Array1, Array2, Axis};
use scirs2_core::random::{thread_rng, Rng};

/// Python-specific configuration wrapper for LogisticRegression
#[derive(Debug, Clone)]
pub struct PyLogisticRegressionConfig {
    pub penalty: String,
    pub c: f64,
    pub fit_intercept: bool,
    pub max_iter: usize,
    pub tol: f64,
    pub solver: String,
    pub random_state: Option<i32>,
    pub class_weight: Option<String>,
    pub multi_class: String,
    pub warm_start: bool,
    pub n_jobs: Option<i32>,
    pub l1_ratio: Option<f64>,
}

impl Default for PyLogisticRegressionConfig {
    fn default() -> Self {
        Self {
            penalty: "l2".to_string(),
            c: 1.0,
            fit_intercept: true,
            max_iter: 100,
            tol: 1e-4,
            solver: "lbfgs".to_string(),
            random_state: None,
            class_weight: None,
            multi_class: "auto".to_string(),
            warm_start: false,
            n_jobs: None,
            l1_ratio: None,
        }
    }
}

/// Basic logistic regression implementation
#[derive(Debug, Clone)]
struct BasicLogisticRegression {
    config: PyLogisticRegressionConfig,
    coef_: Array1<f64>,
    intercept_: f64,
    classes_: Array1<f64>,
    n_features_: usize,
}

impl BasicLogisticRegression {
    fn new(config: PyLogisticRegressionConfig) -> Self {
        Self {
            config,
            coef_: Array1::zeros(1),
            intercept_: 0.0,
            classes_: Array1::zeros(1),
            n_features_: 0,
        }
    }

    fn sigmoid(z: f64) -> f64 {
        1.0 / (1.0 + (-z).exp())
    }

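    // The `fit` method below minimizes the regularized log-loss by plain batch
    // gradient descent. A brief sketch of the objective, written to match the code
    // rather than any external specification:
    //
    //   p_i = sigmoid(w . x_i + b)
    //   L   = sum_i [ -y_i ln(p_i) - (1 - y_i) ln(1 - p_i) ] + (1 / (2C)) * sum_j w_j^2
    //
    // where the L2 term skips the intercept when `fit_intercept` is true. Each
    // iteration accumulates the per-sample gradient (p_i - y_i) * x_i plus
    // (1 / C) * w_j for the regularized weights, then applies a fixed learning rate
    // after dividing the accumulated gradient by the number of samples.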
    fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), SklearsPythonError> {
        let n_samples = x.nrows();
        let n_features = x.ncols();
        self.n_features_ = n_features;

        if n_samples != y.len() {
            return Err(SklearsPythonError::ValidationError(
                "X and y have incompatible shapes".to_string(),
            ));
        }

        // Find unique classes
        let mut classes: Vec<f64> = y.iter().cloned().collect();
        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
        classes.dedup();
        self.classes_ = Array1::from_vec(classes.clone());

        // For now, only support binary classification
        if classes.len() != 2 {
            return Err(SklearsPythonError::ValidationError(
                "Currently only binary classification is supported".to_string(),
            ));
        }

        // Map classes to 0 and 1
        let y_mapped: Array1<f64> = y.mapv(|val| if val == classes[0] { 0.0 } else { 1.0 });

        // Add intercept column if needed
        let x_design = if self.config.fit_intercept {
            let mut x_new = Array2::ones((n_samples, n_features + 1));
            x_new.slice_mut(s![.., 1..]).assign(x);
            x_new
        } else {
            x.clone()
        };

        let n_params = if self.config.fit_intercept {
            n_features + 1
        } else {
            n_features
        };

        // Initialize weights
        let mut rng = thread_rng();
        let mut weights = Array1::from_shape_fn(n_params, |_| rng.gen::<f64>() * 0.01);

        // Gradient descent
        let learning_rate = 0.01;
        let mut prev_loss = f64::INFINITY;
        for _iter in 0..self.config.max_iter {
            let mut total_loss = 0.0;
            let mut gradient: Array1<f64> = Array1::zeros(n_params);

            for i in 0..n_samples {
                let xi = x_design.row(i);
                let yi = y_mapped[i];

                let z = xi.dot(&weights);
                let prediction = Self::sigmoid(z);

                // Log loss contribution
                let loss = if yi == 1.0 {
                    -prediction.ln()
                } else {
                    -(1.0 - prediction).ln()
                };
                total_loss += loss;

                // Gradient contribution
                let error = prediction - yi;
                for j in 0..n_params {
                    gradient[j] += error * xi[j];
                }
            }

            // Apply L2 regularization if configured
            if self.config.penalty == "l2" && self.config.c > 0.0 {
                let reg_strength = 1.0 / self.config.c;
                for j in 0..n_params {
                    // Don't regularize the intercept
                    if !self.config.fit_intercept || j > 0 {
                        gradient[j] += reg_strength * weights[j];
                        total_loss += 0.5 * reg_strength * weights[j] * weights[j];
                    }
                }
            }

            // Update weights
            for j in 0..n_params {
                weights[j] -= learning_rate * gradient[j] / n_samples as f64;
            }

            // Check convergence: stop once the average loss no longer improves by more
            // than `tol` (comparing the loss itself against `tol` would almost never
            // trigger for realistic data).
            let avg_loss = total_loss / n_samples as f64;
            if (prev_loss - avg_loss).abs() < self.config.tol {
                break;
            }
            prev_loss = avg_loss;
        }

        // Extract coefficients and intercept
        if self.config.fit_intercept {
            self.intercept_ = weights[0];
            self.coef_ = weights.slice(s![1..]).to_owned();
        } else {
            self.intercept_ = 0.0;
            self.coef_ = weights;
        }

        Ok(())
    }

    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsPythonError> {
        if x.ncols() != self.n_features_ {
            return Err(SklearsPythonError::ValidationError(format!(
                "X has {} features, but model expects {} features",
                x.ncols(),
                self.n_features_
            )));
        }

        let probabilities = self.predict_proba(x)?;
        let predictions = probabilities
            .axis_iter(Axis(0))
            .map(|row| {
                if row[1] > 0.5 {
                    self.classes_[1]
                } else {
                    self.classes_[0]
                }
            })
            .collect::<Vec<f64>>();

        Ok(Array1::from_vec(predictions))
    }

    fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsPythonError> {
        if x.ncols() != self.n_features_ {
            return Err(SklearsPythonError::ValidationError(format!(
                "X has {} features, but model expects {} features",
                x.ncols(),
                self.n_features_
            )));
        }

        let n_samples = x.nrows();
        let mut probabilities = Array2::zeros((n_samples, 2));

        for i in 0..n_samples {
            let xi = x.row(i);
            let z = xi.dot(&self.coef_) + self.intercept_;
            let prob_class_1 = Self::sigmoid(z);
            let prob_class_0 = 1.0 - prob_class_1;

            probabilities[[i, 0]] = prob_class_0;
            probabilities[[i, 1]] = prob_class_1;
        }

        Ok(probabilities)
    }

    fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<f64, SklearsPythonError> {
        let predictions = self.predict(x)?;
        let correct = y
            .iter()
            .zip(predictions.iter())
            .filter(|(&true_val, &pred_val)| (true_val - pred_val).abs() < 1e-6)
            .count();

        Ok(correct as f64 / y.len() as f64)
    }
}

/// Logistic Regression (aka logit, MaxEnt) classifier.
///
/// In the multiclass case, the training algorithm uses the one-vs-rest (OvR)
/// scheme if the 'multi_class' option is set to 'ovr', and uses the
/// cross-entropy loss if the 'multi_class' option is set to 'multinomial'.
/// (Currently the 'multinomial' option is supported only by the 'lbfgs',
/// 'sag', 'saga' and 'newton-cg' solvers.)
///
/// This class implements regularized logistic regression using the
/// 'liblinear' library, 'newton-cg', 'sag', 'saga' and 'lbfgs' solvers.
/// **Note that regularization is applied by default**. It can handle both
/// dense and sparse input. Use C-ordered arrays or CSR matrices containing
/// 64-bit floats for optimal performance; any other input format will be
/// converted (and copied).
///
/// The 'newton-cg', 'sag', and 'lbfgs' solvers support only L2 regularization
/// with primal formulation, or no regularization. The 'liblinear' solver
/// supports both L1 and L2 regularization, with a dual formulation only for
/// the L2 penalty. The Elastic-Net regularization is only supported by the
/// 'saga' solver.
///
/// Note: the current Rust backend is a basic gradient-descent implementation
/// (see the module-level note); solver and penalty options are accepted for
/// scikit-learn API compatibility, but only binary classification with an
/// optional L2 penalty is currently applied.
///
/// Parameters
/// ----------
/// penalty : {'l1', 'l2', 'elasticnet', None}, default='l2'
///     Specify the norm of the penalty:
///
///     - None: no penalty is added;
///     - 'l2': add an L2 penalty term (the default choice);
///     - 'l1': add an L1 penalty term;
///     - 'elasticnet': both L1 and L2 penalty terms are added.
///
/// dual : bool, default=False
///     Dual or primal formulation. Dual formulation is only implemented for
///     l2 penalty with liblinear solver. Prefer dual=False when
///     n_samples > n_features.
///
/// tol : float, default=1e-4
///     Tolerance for stopping criteria.
///
/// C : float, default=1.0
///     Inverse of regularization strength; must be a positive float.
///     Like in support vector machines, smaller values specify stronger
///     regularization.
///
/// fit_intercept : bool, default=True
///     Specifies if a constant (a.k.a. bias or intercept) should be
///     added to the decision function.
///
/// intercept_scaling : float, default=1
///     Useful only when the solver 'liblinear' is used
///     and self.fit_intercept is set to True. In this case, x becomes
///     [x, self.intercept_scaling],
///     i.e. a "synthetic" feature with constant value equal to
///     intercept_scaling is appended to the instance vector.
///     The intercept becomes intercept_scaling * synthetic_feature_weight.
///
///     Note! the synthetic feature weight is subject to l1/l2 regularization
///     as all other features.
///     To lessen the effect of regularization on synthetic feature weight
///     (and therefore on the intercept) intercept_scaling has to be increased.
///
/// class_weight : dict or 'balanced', default=None
///     Weights associated with classes in the form ``{class_label: weight}``.
///     If not given, all classes are supposed to have weight one.
///
///     The "balanced" mode uses the values of y to automatically adjust
///     weights inversely proportional to class frequencies in the input data
///     as ``n_samples / (n_classes * np.bincount(y))``.
///
///     Note that these weights will be multiplied with sample_weight (passed
///     through the fit method) if sample_weight is specified.
///
/// random_state : int, RandomState instance, default=None
///     Used when ``solver`` == 'sag', 'saga' or 'liblinear' to shuffle the
///     data. See :term:`Glossary <random_state>` for details.
///
/// solver : {'lbfgs', 'liblinear', 'newton-cg', 'newton-cholesky', 'sag', 'saga'}, \
///         default='lbfgs'
///
///     Algorithm to use in the optimization problem. Default is 'lbfgs'.
///     To choose a solver, you might want to consider the following aspects:
///
///         - For small datasets, 'liblinear' is a good choice, whereas 'sag'
///           and 'saga' are faster for large ones;
///         - For multiclass problems, only 'newton-cg', 'sag', 'saga' and
///           'lbfgs' handle multinomial loss;
///         - 'liblinear' is limited to one-versus-rest schemes.
///
/// max_iter : int, default=100
///     Maximum number of iterations taken for the solvers to converge.
///
/// multi_class : {'auto', 'ovr', 'multinomial'}, default='auto'
///     If the option chosen is 'ovr', then a binary problem is fit for each
///     label. For 'multinomial' the loss minimised is the multinomial loss fit
///     across the entire probability distribution, *even when the data is
///     binary*. 'multinomial' is unavailable when solver='liblinear'.
///     'auto' selects 'ovr' if the data is binary, or if solver='liblinear',
///     and otherwise selects 'multinomial'.
///
/// verbose : int, default=0
///     For the liblinear and lbfgs solvers set verbose to any positive
///     number for verbosity.
///
/// warm_start : bool, default=False
///     When set to True, reuse the solution of the previous call to fit as
///     initialization, otherwise, just erase the previous solution.
///     Useless for liblinear solver. See :term:`the Glossary <warm_start>`.
///
/// n_jobs : int, default=None
///     Number of CPU cores used when parallelizing over classes if
///     multi_class='ovr'. This parameter is ignored when the ``solver``
///     is set to 'liblinear' regardless of whether 'multi_class' is specified or
///     not. ``None`` means 1 unless in a
///     :obj:`joblib.parallel_backend` context. ``-1`` means using all
///     processors. See :term:`Glossary <n_jobs>` for more details.
///
/// l1_ratio : float, default=None
///     The Elastic-Net mixing parameter, with ``0 <= l1_ratio <= 1``. Only
///     used if ``penalty='elasticnet'``. Setting ``l1_ratio=0`` is equivalent
///     to using ``penalty='l2'``, while setting ``l1_ratio=1`` is equivalent
///     to using ``penalty='l1'``. For ``0 < l1_ratio < 1``, the penalty is a
///     combination of L1 and L2.
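///     Concretely, the elastic-net penalty used by scikit-learn corresponds to
///     ``(1 - l1_ratio) / 2 * ||w||_2^2 + l1_ratio * ||w||_1`` (with the data term
///     weighted by ``C``); this formula is given for reference only, since the
///     current basic backend does not yet implement the elastic-net penalty.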
///
/// Attributes
/// ----------
/// classes_ : ndarray of shape (n_classes,)
///     A list of class labels known to the classifier.
///
/// coef_ : ndarray of shape (1, n_features) or (n_classes, n_features)
///     Coefficient of the features in the decision function.
///
///     `coef_` is of shape (1, n_features) when the given problem is binary.
///     In particular, when `multi_class='multinomial'`, `coef_` corresponds
///     to outcome 1 (True) and `-coef_` corresponds to outcome 0 (False).
///
/// intercept_ : ndarray of shape (1,) or (n_classes,)
///     Intercept (a.k.a. bias) added to the decision function.
///
///     If `fit_intercept` is set to False, the intercept is set to zero.
///     `intercept_` is of shape (1,) when the given problem is binary.
///     In particular, when `multi_class='multinomial'`, `intercept_`
///     corresponds to outcome 1 (True) and `-intercept_` corresponds to
///     outcome 0 (False).
///
/// n_features_in_ : int
///     Number of features seen during :term:`fit`.
///
/// n_iter_ : ndarray of shape (n_classes,) or (1,)
///     Actual number of iterations for all classes. If binary or multinomial,
///     it returns only 1 element. For liblinear solver, only the maximum
///     number of iterations across all classes is given.
///
/// Examples
/// --------
/// >>> from sklears_python import LogisticRegression
/// >>> from sklearn.datasets import load_iris
/// >>> X, y = load_iris(return_X_y=True)
/// >>> clf = LogisticRegression(random_state=0).fit(X, y)
/// >>> clf.predict(X[:2, :])
/// array([0, 0])
/// >>> clf.predict_proba(X[:2, :])
/// array([[9.8...e-01, 1.8...e-02, 1.4...e-08],
///        [9.7...e-01, 2.8...e-02, ...e-08]])
/// >>> clf.score(X, y)
/// 0.97...
///
/// Note that the example above illustrates the target scikit-learn-compatible
/// API; the current basic backend supports binary classification only, so a
/// two-class dataset is required until multiclass support is wired in.
///
/// Notes
/// -----
/// The underlying C implementation uses a random number generator to
/// select features when fitting the model. It is thus not uncommon
/// to have slightly different results for the same input data. If
/// that happens, try with a smaller tol parameter.
///
/// Predict output may not match that of standalone liblinear in certain
/// cases. See :ref:`differences from liblinear <liblinear_differences>`
/// in the narrative documentation.
///
/// References
/// ----------
/// L-BFGS-B -- Software for Large-scale Bound-constrained Optimization
/// Ciyou Zhu, Richard Byrd, Jorge Nocedal and Jose Luis Morales.
/// http://users.iems.northwestern.edu/~nocedal/lbfgsb.html
///
/// LIBLINEAR -- A Library for Large Linear Classification
/// https://www.csie.ntu.edu.tw/~cjlin/liblinear/
///
/// SAG -- Mark Schmidt, Nicolas Le Roux, and Francis Bach
/// Minimizing Finite Sums with the Stochastic Average Gradient
/// https://hal.inria.fr/hal-00860051/document
///
/// SAGA -- Defazio, A., Bach F. & Lacoste-Julien S. (2014).
/// SAGA: A Fast Incremental Gradient Method With Support
/// for Non-Strongly Convex Composite Objectives
/// https://arxiv.org/abs/1407.0202
///
/// Hsiang-Fu Yu, Fang-Lan Huang, Chih-Jen Lin (2011). Dual coordinate descent
/// methods for logistic regression and maximum entropy models.
/// Machine Learning 85(1-2):41-75.
/// https://www.csie.ntu.edu.tw/~cjlin/papers/maxent_dual.pdf
#[pyclass(name = "LogisticRegression")]
pub struct PyLogisticRegression {
    py_config: PyLogisticRegressionConfig,
    fitted_model: Option<BasicLogisticRegression>,
}

#[pymethods]
impl PyLogisticRegression {
    #[new]
    #[pyo3(signature = (penalty="l2", dual=false, tol=1e-4, c=1.0, fit_intercept=true, intercept_scaling=1.0, class_weight=None, random_state=None, solver="lbfgs", max_iter=100, multi_class="auto", verbose=0, warm_start=false, n_jobs=None, l1_ratio=None))]
    fn new(
        penalty: &str,
        dual: bool,
        tol: f64,
        c: f64,
        fit_intercept: bool,
        intercept_scaling: f64,
        class_weight: Option<&str>,
        random_state: Option<i32>,
        solver: &str,
        max_iter: usize,
        multi_class: &str,
        verbose: i32,
        warm_start: bool,
        n_jobs: Option<i32>,
        l1_ratio: Option<f64>,
    ) -> Self {
        // Note: Some parameters are sklearn-specific and don't directly map to our implementation
        let _dual = dual;
        let _intercept_scaling = intercept_scaling;
        let _verbose = verbose;

        let py_config = PyLogisticRegressionConfig {
            penalty: penalty.to_string(),
            c,
            fit_intercept,
            max_iter,
            tol,
            solver: solver.to_string(),
            random_state,
            class_weight: class_weight.map(|s| s.to_string()),
            multi_class: multi_class.to_string(),
            warm_start,
            n_jobs,
            l1_ratio,
        };

        Self {
            py_config,
            fitted_model: None,
        }
    }

    /// Fit the logistic regression model.
    ///
    /// Returns `self` so that calls can be chained as in scikit-learn
    /// (e.g. ``LogisticRegression().fit(X, y)``).
    fn fit<'py>(
        mut slf: PyRefMut<'py, Self>,
        x: PyReadonlyArray2<'py, f64>,
        y: PyReadonlyArray1<'py, f64>,
    ) -> PyResult<PyRefMut<'py, Self>> {
        let x_array = x.as_array().to_owned();
        let y_array = y.as_array().to_owned();

        // Validate input arrays using enhanced validation
        validate_fit_arrays_enhanced(&x_array, &y_array).map_err(PyErr::from)?;

        // Create and fit model
        let mut model = BasicLogisticRegression::new(slf.py_config.clone());
        match model.fit(&x_array, &y_array) {
            Ok(()) => {
                slf.fitted_model = Some(model);
                Ok(slf)
            }
            Err(e) => Err(PyErr::from(e)),
        }
    }

    /// Predict class labels for samples
    fn predict(&self, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = x.as_array().to_owned();
        validate_predict_array(&x_array)?;

        match fitted.predict(&x_array) {
            Ok(predictions) => {
                let py = unsafe { Python::assume_attached() };
                Ok(predictions.into_pyarray(py).into())
            }
            Err(e) => Err(PyErr::from(e)),
        }
    }

    /// Predict class probabilities for samples
    fn predict_proba(&self, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray2<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = x.as_array().to_owned();
        validate_predict_array(&x_array)?;

        match fitted.predict_proba(&x_array) {
            Ok(probabilities) => {
                let py = unsafe { Python::assume_attached() };
                Ok(probabilities.into_pyarray(py).into())
            }
            Err(e) => Err(PyErr::from(e)),
        }
    }

    /// Get model coefficients
    #[getter]
    fn coef_(&self) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let py = unsafe { Python::assume_attached() };
        Ok(fitted.coef_.clone().into_pyarray(py).into())
    }

    /// Get model intercept
    #[getter]
    fn intercept_(&self) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        Ok(fitted.intercept_)
    }

    /// Get unique class labels
    #[getter]
    fn classes_(&self) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let py = unsafe { Python::assume_attached() };
        Ok(fitted.classes_.clone().into_pyarray(py).into())
    }

    /// Calculate accuracy score
    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = x.as_array().to_owned();
        let y_array = y.as_array().to_owned();

        match fitted.score(&x_array, &y_array) {
            Ok(score) => Ok(score),
            Err(e) => Err(PyErr::from(e)),
        }
    }

    /// Get number of features
    #[getter]
    fn n_features_in_(&self) -> PyResult<usize> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        Ok(fitted.n_features_)
    }

    /// Return parameters for this estimator (sklearn compatibility)
    fn get_params(&self, deep: Option<bool>) -> PyResult<Py<PyDict>> {
        let _deep = deep.unwrap_or(true);

        let py = unsafe { Python::assume_attached() };
        let dict = PyDict::new(py);

        dict.set_item("penalty", &self.py_config.penalty)?;
        dict.set_item("C", self.py_config.c)?;
        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
        dict.set_item("max_iter", self.py_config.max_iter)?;
        dict.set_item("tol", self.py_config.tol)?;
        dict.set_item("solver", &self.py_config.solver)?;
        dict.set_item("random_state", self.py_config.random_state)?;
        dict.set_item("class_weight", &self.py_config.class_weight)?;
        dict.set_item("multi_class", &self.py_config.multi_class)?;
        dict.set_item("warm_start", self.py_config.warm_start)?;
        dict.set_item("n_jobs", self.py_config.n_jobs)?;
        dict.set_item("l1_ratio", self.py_config.l1_ratio)?;

        Ok(dict.into())
    }

    /// Set parameters for this estimator (sklearn compatibility)
    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
        // Update configuration parameters
        if let Some(penalty) = kwargs.get_item("penalty")? {
            let penalty_str: String = penalty.extract()?;
            self.py_config.penalty = penalty_str;
        }
        if let Some(c) = kwargs.get_item("C")? {
            self.py_config.c = c.extract()?;
        }
        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
            self.py_config.fit_intercept = fit_intercept.extract()?;
        }
        if let Some(max_iter) = kwargs.get_item("max_iter")? {
            self.py_config.max_iter = max_iter.extract()?;
        }
        if let Some(tol) = kwargs.get_item("tol")? {
            self.py_config.tol = tol.extract()?;
        }
        if let Some(solver) = kwargs.get_item("solver")? {
            let solver_str: String = solver.extract()?;
            self.py_config.solver = solver_str;
        }
        if let Some(random_state) = kwargs.get_item("random_state")? {
            self.py_config.random_state = random_state.extract()?;
        }
        if let Some(class_weight) = kwargs.get_item("class_weight")? {
            let weight_str: Option<String> = class_weight.extract()?;
            self.py_config.class_weight = weight_str;
        }
        if let Some(multi_class) = kwargs.get_item("multi_class")? {
            let multi_class_str: String = multi_class.extract()?;
            self.py_config.multi_class = multi_class_str;
        }
        if let Some(warm_start) = kwargs.get_item("warm_start")? {
            self.py_config.warm_start = warm_start.extract()?;
        }
        if let Some(n_jobs) = kwargs.get_item("n_jobs")? {
            self.py_config.n_jobs = n_jobs.extract()?;
        }
        if let Some(l1_ratio) = kwargs.get_item("l1_ratio")? {
            self.py_config.l1_ratio = l1_ratio.extract()?;
        }

        // Clear fitted model since config changed
        self.fitted_model = None;

        Ok(())
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!(
            "LogisticRegression(penalty='{}', C={}, fit_intercept={}, max_iter={}, tol={}, solver='{}', random_state={:?}, multi_class='{}')",
            self.py_config.penalty,
            self.py_config.c,
            self.py_config.fit_intercept,
            self.py_config.max_iter,
            self.py_config.tol,
            self.py_config.solver,
            self.py_config.random_state,
            self.py_config.multi_class
        )
    }
}
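
// A minimal smoke test for the internal `BasicLogisticRegression` (a sketch added
// for illustration; it assumes the crate's usual `cargo test` setup and exercises
// only the pure-Rust core, not the PyO3 bindings).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_logistic_regression_separates_simple_binary_data() {
        // Six 1-D samples: negative values labelled 0.0, positive values labelled 1.0.
        let x = Array2::from_shape_vec((6, 1), vec![-2.0, -1.5, -1.0, 1.0, 1.5, 2.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let mut model = BasicLogisticRegression::new(PyLogisticRegressionConfig::default());
        model
            .fit(&x, &y)
            .expect("fit should succeed on valid binary data");

        // The learned coefficient should be positive and the data perfectly separated.
        assert!(model.coef_[0] > 0.0);
        assert!((model.score(&x, &y).unwrap() - 1.0).abs() < 1e-12);

        // Probabilities for each sample should sum to 1.
        let proba = model.predict_proba(&x).unwrap();
        for row in proba.axis_iter(Axis(0)) {
            assert!((row[0] + row[1] - 1.0).abs() < 1e-9);
        }
    }
}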