1use super::common::*;
8use pyo3::types::PyDict;
9use pyo3::Bound;
10use sklears_core::traits::{Fit, Predict, Score, Trained};
11use sklears_linear::{LinearRegression, LinearRegressionConfig, Penalty};
12
/// Hyper-parameters of the Python-facing `Lasso` estimator, kept exactly as supplied from Python.
///
/// Field names mirror scikit-learn's `Lasso` parameters (with `copy_X` spelled `copy_x`).
#[derive(Debug, Clone)]
pub struct PyLassoConfig {
    /// Strength of the L1 penalty.
    pub alpha: f64,
    /// Whether an intercept term is fitted.
    pub fit_intercept: bool,
    /// scikit-learn `copy_X` parameter.
    pub copy_x: bool,
    /// scikit-learn `max_iter` parameter.
    pub max_iter: usize,
    /// scikit-learn `tol` parameter.
    pub tol: f64,
    /// scikit-learn `warm_start` parameter.
    pub warm_start: bool,
    /// scikit-learn `positive` parameter.
    pub positive: bool,
    /// scikit-learn `random_state` parameter; `None` when unset.
    pub random_state: Option<i32>,
    /// scikit-learn `selection` parameter, stored as its string name.
    pub selection: String,
}

impl Default for PyLassoConfig {
    /// Builds the configuration used when no arguments are given:
    /// `alpha=1.0`, `fit_intercept=true`, `copy_x=true`, `max_iter=1000`, `tol=1e-4`,
    /// `warm_start=false`, `positive=false`, `random_state=None`, `selection="cyclic"`.
    fn default() -> Self {
        let selection = String::from("cyclic");
        PyLassoConfig {
            selection,
            random_state: None,
            positive: false,
            warm_start: false,
            tol: 1e-4,
            max_iter: 1000,
            copy_x: true,
            fit_intercept: true,
            alpha: 1.0,
        }
    }
}
42
43impl From<PyLassoConfig> for LinearRegressionConfig {
44 fn from(py_config: PyLassoConfig) -> Self {
45 LinearRegressionConfig {
46 fit_intercept: py_config.fit_intercept,
47 penalty: Penalty::L1(py_config.alpha),
48 max_iter: py_config.max_iter,
49 tol: py_config.tol,
50 warm_start: py_config.warm_start,
51 ..Default::default()
52 }
53 }
54}
55
56#[pyclass(name = "Lasso")]
58pub struct PyLasso {
59 py_config: PyLassoConfig,
61 fitted_model: Option<LinearRegression<Trained>>,
63}
64
65#[pymethods]
66impl PyLasso {
67 #[new]
68 #[allow(clippy::too_many_arguments)]
69 #[pyo3(signature = (alpha=1.0, fit_intercept=true, copy_x=true, max_iter=1000, tol=1e-4, warm_start=false, positive=false, random_state=None, selection="cyclic"))]
70 fn new(
71 alpha: f64,
72 fit_intercept: bool,
73 copy_x: bool,
74 max_iter: usize,
75 tol: f64,
76 warm_start: bool,
77 positive: bool,
78 random_state: Option<i32>,
79 selection: &str,
80 ) -> Self {
81 let py_config = PyLassoConfig {
82 alpha,
83 fit_intercept,
84 copy_x,
85 max_iter,
86 tol,
87 warm_start,
88 positive,
89 random_state,
90 selection: selection.to_string(),
91 };
92
93 Self {
94 py_config,
95 fitted_model: None,
96 }
97 }
98
99 fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
101 let x_array = pyarray_to_core_array2(x)?;
102 let y_array = pyarray_to_core_array1(y)?;
103
104 validate_fit_arrays(&x_array, &y_array)?;
106
107 let model = LinearRegression::lasso(self.py_config.alpha)
109 .fit_intercept(self.py_config.fit_intercept);
110
111 match model.fit(&x_array, &y_array) {
113 Ok(fitted_model) => {
114 self.fitted_model = Some(fitted_model);
115 Ok(())
116 }
117 Err(e) => Err(PyValueError::new_err(format!(
118 "Failed to fit Lasso model: {:?}",
119 e
120 ))),
121 }
122 }
123
124 fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
126 let fitted = self
127 .fitted_model
128 .as_ref()
129 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
130
131 let x_array = pyarray_to_core_array2(x)?;
132 validate_predict_array(&x_array)?;
133
134 match fitted.predict(&x_array) {
135 Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
136 Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
137 }
138 }
139
140 #[getter]
142 fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
143 let fitted = self
144 .fitted_model
145 .as_ref()
146 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
147
148 Ok(core_array1_to_py(py, fitted.coef()))
149 }
150
151 #[getter]
153 fn intercept_(&self) -> PyResult<f64> {
154 let fitted = self
155 .fitted_model
156 .as_ref()
157 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
158
159 Ok(fitted.intercept().unwrap_or(0.0))
160 }
161
162 fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
164 let fitted = self
165 .fitted_model
166 .as_ref()
167 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
168
169 let x_array = pyarray_to_core_array2(x)?;
170 let y_array = pyarray_to_core_array1(y)?;
171
172 match fitted.score(&x_array, &y_array) {
173 Ok(score) => Ok(score),
174 Err(e) => Err(PyValueError::new_err(format!(
175 "Score calculation failed: {:?}",
176 e
177 ))),
178 }
179 }
180
181 #[getter]
183 fn n_features_in_(&self) -> PyResult<usize> {
184 let fitted = self
185 .fitted_model
186 .as_ref()
187 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
188
189 Ok(fitted.coef().len())
191 }
192
193 fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
195 let _deep = deep.unwrap_or(true);
196
197 let dict = PyDict::new(py);
198
199 dict.set_item("alpha", self.py_config.alpha)?;
200 dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
201 dict.set_item("copy_X", self.py_config.copy_x)?;
202 dict.set_item("max_iter", self.py_config.max_iter)?;
203 dict.set_item("tol", self.py_config.tol)?;
204 dict.set_item("warm_start", self.py_config.warm_start)?;
205 dict.set_item("positive", self.py_config.positive)?;
206 dict.set_item("random_state", self.py_config.random_state)?;
207 dict.set_item("selection", &self.py_config.selection)?;
208
209 Ok(dict.into())
210 }
211
212 fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
214 if let Some(alpha) = kwargs.get_item("alpha")? {
216 self.py_config.alpha = alpha.extract()?;
217 }
218 if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
219 self.py_config.fit_intercept = fit_intercept.extract()?;
220 }
221 if let Some(copy_x) = kwargs.get_item("copy_X")? {
222 self.py_config.copy_x = copy_x.extract()?;
223 }
224 if let Some(max_iter) = kwargs.get_item("max_iter")? {
225 self.py_config.max_iter = max_iter.extract()?;
226 }
227 if let Some(tol) = kwargs.get_item("tol")? {
228 self.py_config.tol = tol.extract()?;
229 }
230 if let Some(warm_start) = kwargs.get_item("warm_start")? {
231 self.py_config.warm_start = warm_start.extract()?;
232 }
233 if let Some(positive) = kwargs.get_item("positive")? {
234 self.py_config.positive = positive.extract()?;
235 }
236 if let Some(random_state) = kwargs.get_item("random_state")? {
237 self.py_config.random_state = random_state.extract()?;
238 }
239 if let Some(selection) = kwargs.get_item("selection")? {
240 let selection_str: String = selection.extract()?;
241 self.py_config.selection = selection_str;
242 }
243
244 self.fitted_model = None;
246
247 Ok(())
248 }
249
250 fn __repr__(&self) -> String {
252 format!(
253 "Lasso(alpha={}, fit_intercept={}, copy_X={}, max_iter={}, tol={}, warm_start={}, positive={}, random_state={:?}, selection='{}')",
254 self.py_config.alpha,
255 self.py_config.fit_intercept,
256 self.py_config.copy_x,
257 self.py_config.max_iter,
258 self.py_config.tol,
259 self.py_config.warm_start,
260 self.py_config.positive,
261 self.py_config.random_state,
262 self.py_config.selection
263 )
264 }
265}