1use super::common::*;
8use pyo3::types::PyDict;
9use pyo3::Bound;
10use sklears_core::traits::{Fit, Predict, PredictProba, Score, Trained};
11use sklears_linear::{LogisticRegression, LogisticRegressionConfig, Penalty, Solver};
12
/// Hyper-parameters of the Python-facing logistic regression estimator.
///
/// Values are stored in the form they arrive from Python (string names for
/// the penalty and solver, `C` as a plain float). They are translated into
/// `sklears_linear` types by the code that consumes this struct.
#[derive(Debug, Clone, PartialEq)]
pub struct PyLogisticRegressionConfig {
    /// Penalty name. `"l1"`, `"l2"` and `"elasticnet"` are the names matched
    /// explicitly by the consumers of this config.
    pub penalty: String,
    /// Regularization parameter `C`; consumers use `1.0 / c` as the penalty
    /// strength.
    pub c: f64,
    /// Whether an intercept term is fitted.
    pub fit_intercept: bool,
    /// Iteration limit handed to the solver.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: f64,
    /// Solver name. `"lbfgs"`, `"sag"`, `"saga"` and `"newton-cg"` are the
    /// names matched explicitly by the consumers of this config.
    pub solver: String,
    /// Optional random seed.
    pub random_state: Option<i32>,
    /// Optional class-weight setting, stored as given.
    pub class_weight: Option<String>,
    /// Multi-class strategy name, stored as given.
    pub multi_class: String,
    /// Warm-start flag, stored as given.
    pub warm_start: bool,
    /// Optional job count, stored as given.
    pub n_jobs: Option<i32>,
    /// Elastic-net mixing ratio; consumers fall back to `0.5` when unset.
    pub l1_ratio: Option<f64>,
}
29
30impl Default for PyLogisticRegressionConfig {
31 fn default() -> Self {
32 Self {
33 penalty: "l2".to_string(),
34 c: 1.0,
35 fit_intercept: true,
36 max_iter: 100,
37 tol: 1e-4,
38 solver: "lbfgs".to_string(),
39 random_state: None,
40 class_weight: None,
41 multi_class: "auto".to_string(),
42 warm_start: false,
43 n_jobs: None,
44 l1_ratio: None,
45 }
46 }
47}
48
49impl From<PyLogisticRegressionConfig> for LogisticRegressionConfig {
50 fn from(py_config: PyLogisticRegressionConfig) -> Self {
51 let penalty = match py_config.penalty.as_str() {
53 "l1" => Penalty::L1(1.0 / py_config.c),
54 "l2" => Penalty::L2(1.0 / py_config.c),
55 "elasticnet" => Penalty::ElasticNet {
56 alpha: 1.0 / py_config.c,
57 l1_ratio: py_config.l1_ratio.unwrap_or(0.5),
58 },
59 _ => Penalty::L2(1.0 / py_config.c), };
61
62 let solver = match py_config.solver.as_str() {
64 "lbfgs" => Solver::Lbfgs,
65 "sag" => Solver::Sag,
66 "saga" => Solver::Saga,
67 "newton-cg" => Solver::Newton,
68 _ => Solver::Auto, };
70
71 LogisticRegressionConfig {
72 penalty,
73 solver,
74 max_iter: py_config.max_iter,
75 tol: py_config.tol,
76 fit_intercept: py_config.fit_intercept,
77 random_state: py_config.random_state.map(|s| s as u64),
78 }
79 }
80}
81
82#[pyclass(name = "LogisticRegression")]
109pub struct PyLogisticRegression {
110 py_config: PyLogisticRegressionConfig,
111 fitted_model: Option<LogisticRegression<Trained>>,
112}
113
114#[pymethods]
115impl PyLogisticRegression {
116 #[new]
117 #[allow(clippy::too_many_arguments)]
118 #[pyo3(signature = (penalty="l2", dual=false, tol=1e-4, c=1.0, fit_intercept=true, intercept_scaling=1.0, class_weight=None, random_state=None, solver="lbfgs", max_iter=100, multi_class="auto", verbose=0, warm_start=false, n_jobs=None, l1_ratio=None))]
119 fn new(
120 penalty: &str,
121 dual: bool,
122 tol: f64,
123 c: f64,
124 fit_intercept: bool,
125 intercept_scaling: f64,
126 class_weight: Option<&str>,
127 random_state: Option<i32>,
128 solver: &str,
129 max_iter: usize,
130 multi_class: &str,
131 verbose: i32,
132 warm_start: bool,
133 n_jobs: Option<i32>,
134 l1_ratio: Option<f64>,
135 ) -> Self {
136 let _dual = dual;
138 let _intercept_scaling = intercept_scaling;
139 let _verbose = verbose;
140
141 let py_config = PyLogisticRegressionConfig {
142 penalty: penalty.to_string(),
143 c,
144 fit_intercept,
145 max_iter,
146 tol,
147 solver: solver.to_string(),
148 random_state,
149 class_weight: class_weight.map(|s| s.to_string()),
150 multi_class: multi_class.to_string(),
151 warm_start,
152 n_jobs,
153 l1_ratio,
154 };
155
156 Self {
157 py_config,
158 fitted_model: None,
159 }
160 }
161
162 fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
164 let x_array = pyarray_to_core_array2(x)?;
165 let y_array = pyarray_to_core_array1(y)?;
166
167 validate_fit_arrays(&x_array, &y_array)?;
169
170 let model = LogisticRegression::new()
172 .max_iter(self.py_config.max_iter)
173 .fit_intercept(self.py_config.fit_intercept);
174
175 let model = match self.py_config.penalty.as_str() {
177 "l1" => model.penalty(Penalty::L1(1.0 / self.py_config.c)),
178 "l2" => model.penalty(Penalty::L2(1.0 / self.py_config.c)),
179 "elasticnet" => model.penalty(Penalty::ElasticNet {
180 alpha: 1.0 / self.py_config.c,
181 l1_ratio: self.py_config.l1_ratio.unwrap_or(0.5),
182 }),
183 _ => model, };
185
186 let model = match self.py_config.solver.as_str() {
188 "lbfgs" => model.solver(Solver::Lbfgs),
189 "sag" => model.solver(Solver::Sag),
190 "saga" => model.solver(Solver::Saga),
191 "newton-cg" => model.solver(Solver::Newton),
192 _ => model.solver(Solver::Auto),
193 };
194
195 let model = if let Some(rs) = self.py_config.random_state {
197 model.random_state(rs as u64)
198 } else {
199 model
200 };
201
202 match model.fit(&x_array, &y_array) {
204 Ok(fitted_model) => {
205 self.fitted_model = Some(fitted_model);
206 Ok(())
207 }
208 Err(e) => Err(PyValueError::new_err(format!(
209 "Failed to fit Logistic Regression model: {:?}",
210 e
211 ))),
212 }
213 }
214
215 fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
217 let fitted = self
218 .fitted_model
219 .as_ref()
220 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
221
222 let x_array = pyarray_to_core_array2(x)?;
223 validate_predict_array(&x_array)?;
224
225 match fitted.predict(&x_array) {
226 Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
227 Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
228 }
229 }
230
231 fn predict_proba(
233 &self,
234 py: Python<'_>,
235 x: PyReadonlyArray2<f64>,
236 ) -> PyResult<Py<PyArray2<f64>>> {
237 let fitted = self
238 .fitted_model
239 .as_ref()
240 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
241
242 let x_array = pyarray_to_core_array2(x)?;
243 validate_predict_array(&x_array)?;
244
245 match fitted.predict_proba(&x_array) {
246 Ok(probabilities) => core_array2_to_py(py, &probabilities),
247 Err(e) => Err(PyValueError::new_err(format!(
248 "Probability prediction failed: {:?}",
249 e
250 ))),
251 }
252 }
253
254 #[getter]
256 fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
257 let fitted = self
258 .fitted_model
259 .as_ref()
260 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
261
262 Ok(core_array1_to_py(py, fitted.coef()))
263 }
264
265 #[getter]
267 fn intercept_(&self) -> PyResult<f64> {
268 let fitted = self
269 .fitted_model
270 .as_ref()
271 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
272
273 Ok(fitted.intercept().unwrap_or(0.0))
274 }
275
276 #[getter]
278 fn classes_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
279 let fitted = self
280 .fitted_model
281 .as_ref()
282 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
283
284 Ok(core_array1_to_py(py, fitted.classes()))
285 }
286
287 fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
289 let fitted = self
290 .fitted_model
291 .as_ref()
292 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
293
294 let x_array = pyarray_to_core_array2(x)?;
295 let y_array = pyarray_to_core_array1(y)?;
296
297 match fitted.score(&x_array, &y_array) {
298 Ok(score) => Ok(score),
299 Err(e) => Err(PyValueError::new_err(format!(
300 "Score calculation failed: {:?}",
301 e
302 ))),
303 }
304 }
305
306 #[getter]
308 fn n_features_in_(&self) -> PyResult<usize> {
309 let fitted = self
310 .fitted_model
311 .as_ref()
312 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
313
314 Ok(fitted.coef().len())
316 }
317
318 fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
320 let _deep = deep.unwrap_or(true);
321
322 let dict = PyDict::new(py);
323
324 dict.set_item("penalty", &self.py_config.penalty)?;
325 dict.set_item("C", self.py_config.c)?;
326 dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
327 dict.set_item("max_iter", self.py_config.max_iter)?;
328 dict.set_item("tol", self.py_config.tol)?;
329 dict.set_item("solver", &self.py_config.solver)?;
330 dict.set_item("random_state", self.py_config.random_state)?;
331 dict.set_item("class_weight", &self.py_config.class_weight)?;
332 dict.set_item("multi_class", &self.py_config.multi_class)?;
333 dict.set_item("warm_start", self.py_config.warm_start)?;
334 dict.set_item("n_jobs", self.py_config.n_jobs)?;
335 dict.set_item("l1_ratio", self.py_config.l1_ratio)?;
336
337 Ok(dict.into())
338 }
339
340 fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
342 if let Some(penalty) = kwargs.get_item("penalty")? {
344 let penalty_str: String = penalty.extract()?;
345 self.py_config.penalty = penalty_str;
346 }
347 if let Some(c) = kwargs.get_item("C")? {
348 self.py_config.c = c.extract()?;
349 }
350 if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
351 self.py_config.fit_intercept = fit_intercept.extract()?;
352 }
353 if let Some(max_iter) = kwargs.get_item("max_iter")? {
354 self.py_config.max_iter = max_iter.extract()?;
355 }
356 if let Some(tol) = kwargs.get_item("tol")? {
357 self.py_config.tol = tol.extract()?;
358 }
359 if let Some(solver) = kwargs.get_item("solver")? {
360 let solver_str: String = solver.extract()?;
361 self.py_config.solver = solver_str;
362 }
363 if let Some(random_state) = kwargs.get_item("random_state")? {
364 self.py_config.random_state = random_state.extract()?;
365 }
366 if let Some(class_weight) = kwargs.get_item("class_weight")? {
367 let weight_str: Option<String> = class_weight.extract()?;
368 self.py_config.class_weight = weight_str;
369 }
370 if let Some(multi_class) = kwargs.get_item("multi_class")? {
371 let multi_class_str: String = multi_class.extract()?;
372 self.py_config.multi_class = multi_class_str;
373 }
374 if let Some(warm_start) = kwargs.get_item("warm_start")? {
375 self.py_config.warm_start = warm_start.extract()?;
376 }
377 if let Some(n_jobs) = kwargs.get_item("n_jobs")? {
378 self.py_config.n_jobs = n_jobs.extract()?;
379 }
380 if let Some(l1_ratio) = kwargs.get_item("l1_ratio")? {
381 self.py_config.l1_ratio = l1_ratio.extract()?;
382 }
383
384 self.fitted_model = None;
386
387 Ok(())
388 }
389
390 fn __repr__(&self) -> String {
392 format!(
393 "LogisticRegression(penalty='{}', C={}, fit_intercept={}, max_iter={}, tol={}, solver='{}', random_state={:?}, multi_class='{}')",
394 self.py_config.penalty,
395 self.py_config.c,
396 self.py_config.fit_intercept,
397 self.py_config.max_iter,
398 self.py_config.tol,
399 self.py_config.solver,
400 self.py_config.random_state,
401 self.py_config.multi_class
402 )
403 }
404}