1use super::common::*;
8use numpy::IntoPyArray;
9use pyo3::types::PyDict;
10use pyo3::Bound;
11use sklears_core::traits::{Fit, Predict, Score, Trained};
12use sklears_linear::{LinearRegression, LinearRegressionConfig, Penalty};
13
/// Hyper-parameters of the Python-facing `Lasso` estimator, stored as passed from Python.
///
/// Field names follow Rust conventions; `copy_x` is exposed to Python as `copy_X`.
#[derive(Debug, Clone, PartialEq)]
pub struct PyLassoConfig {
    /// Strength of the L1 penalty.
    pub alpha: f64,
    /// Whether the model should estimate an intercept term.
    pub fit_intercept: bool,
    /// Python `copy_X` flag.
    pub copy_x: bool,
    /// Iteration cap for the solver.
    pub max_iter: usize,
    /// Convergence tolerance for the solver.
    pub tol: f64,
    /// Python `warm_start` flag.
    pub warm_start: bool,
    /// Python `positive` flag.
    pub positive: bool,
    /// Optional seed; `None` means no seed was given.
    pub random_state: Option<i32>,
    /// Coordinate-selection strategy name (e.g. `"cyclic"`).
    pub selection: String,
}
27
28impl Default for PyLassoConfig {
29 fn default() -> Self {
30 Self {
31 alpha: 1.0,
32 fit_intercept: true,
33 copy_x: true,
34 max_iter: 1000,
35 tol: 1e-4,
36 warm_start: false,
37 positive: false,
38 random_state: None,
39 selection: "cyclic".to_string(),
40 }
41 }
42}
43
44impl From<PyLassoConfig> for LinearRegressionConfig {
45 fn from(py_config: PyLassoConfig) -> Self {
46 let mut config = LinearRegressionConfig::default();
47 config.fit_intercept = py_config.fit_intercept;
48 config.penalty = Penalty::L1(py_config.alpha);
49 config.max_iter = py_config.max_iter;
50 config.tol = py_config.tol;
51 config.warm_start = py_config.warm_start;
52 config
53 }
54}
55
56#[pyclass(name = "Lasso")]
58pub struct PyLasso {
59 py_config: PyLassoConfig,
61 fitted_model: Option<LinearRegression<Trained>>,
63}
64
65#[pymethods]
66impl PyLasso {
67 #[new]
68 #[pyo3(signature = (alpha=1.0, fit_intercept=true, copy_x=true, max_iter=1000, tol=1e-4, warm_start=false, positive=false, random_state=None, selection="cyclic"))]
69 fn new(
70 alpha: f64,
71 fit_intercept: bool,
72 copy_x: bool,
73 max_iter: usize,
74 tol: f64,
75 warm_start: bool,
76 positive: bool,
77 random_state: Option<i32>,
78 selection: &str,
79 ) -> Self {
80 let py_config = PyLassoConfig {
81 alpha,
82 fit_intercept,
83 copy_x,
84 max_iter,
85 tol,
86 warm_start,
87 positive,
88 random_state,
89 selection: selection.to_string(),
90 };
91
92 Self {
93 py_config,
94 fitted_model: None,
95 }
96 }
97
98 fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
100 let x_array = x.as_array().to_owned();
101 let y_array = y.as_array().to_owned();
102
103 validate_fit_arrays(&x_array, &y_array)?;
105
106 let model = LinearRegression::lasso(self.py_config.alpha)
108 .fit_intercept(self.py_config.fit_intercept);
109
110 match model.fit(&x_array, &y_array) {
112 Ok(fitted_model) => {
113 self.fitted_model = Some(fitted_model);
114 Ok(())
115 }
116 Err(e) => Err(PyValueError::new_err(format!(
117 "Failed to fit Lasso model: {:?}",
118 e
119 ))),
120 }
121 }
122
123 fn predict(&self, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
125 let fitted = self
126 .fitted_model
127 .as_ref()
128 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
129
130 let x_array = x.as_array().to_owned();
131 validate_predict_array(&x_array)?;
132
133 match fitted.predict(&x_array) {
134 Ok(predictions) => {
135 let py = unsafe { Python::assume_attached() };
136 Ok(predictions.into_pyarray(py).into())
137 }
138 Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
139 }
140 }
141
142 #[getter]
144 fn coef_(&self) -> PyResult<Py<PyArray1<f64>>> {
145 let fitted = self
146 .fitted_model
147 .as_ref()
148 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
149
150 let py = unsafe { Python::assume_attached() };
151 Ok(fitted.coef().clone().into_pyarray(py).into())
152 }
153
154 #[getter]
156 fn intercept_(&self) -> PyResult<f64> {
157 let fitted = self
158 .fitted_model
159 .as_ref()
160 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
161
162 Ok(fitted.intercept().unwrap_or(0.0))
163 }
164
165 fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
167 let fitted = self
168 .fitted_model
169 .as_ref()
170 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
171
172 let x_array = x.as_array().to_owned();
173 let y_array = y.as_array().to_owned();
174
175 match fitted.score(&x_array, &y_array) {
176 Ok(score) => Ok(score),
177 Err(e) => Err(PyValueError::new_err(format!(
178 "Score calculation failed: {:?}",
179 e
180 ))),
181 }
182 }
183
184 #[getter]
186 fn n_features_in_(&self) -> PyResult<usize> {
187 let fitted = self
188 .fitted_model
189 .as_ref()
190 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
191
192 Ok(fitted.coef().len())
194 }
195
196 fn get_params(&self, deep: Option<bool>) -> PyResult<Py<PyDict>> {
198 let _deep = deep.unwrap_or(true);
199
200 let py = unsafe { Python::assume_attached() };
201 let dict = PyDict::new(py);
202
203 dict.set_item("alpha", self.py_config.alpha)?;
204 dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
205 dict.set_item("copy_X", self.py_config.copy_x)?;
206 dict.set_item("max_iter", self.py_config.max_iter)?;
207 dict.set_item("tol", self.py_config.tol)?;
208 dict.set_item("warm_start", self.py_config.warm_start)?;
209 dict.set_item("positive", self.py_config.positive)?;
210 dict.set_item("random_state", self.py_config.random_state)?;
211 dict.set_item("selection", &self.py_config.selection)?;
212
213 Ok(dict.into())
214 }
215
216 fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
218 if let Some(alpha) = kwargs.get_item("alpha")? {
220 self.py_config.alpha = alpha.extract()?;
221 }
222 if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
223 self.py_config.fit_intercept = fit_intercept.extract()?;
224 }
225 if let Some(copy_x) = kwargs.get_item("copy_X")? {
226 self.py_config.copy_x = copy_x.extract()?;
227 }
228 if let Some(max_iter) = kwargs.get_item("max_iter")? {
229 self.py_config.max_iter = max_iter.extract()?;
230 }
231 if let Some(tol) = kwargs.get_item("tol")? {
232 self.py_config.tol = tol.extract()?;
233 }
234 if let Some(warm_start) = kwargs.get_item("warm_start")? {
235 self.py_config.warm_start = warm_start.extract()?;
236 }
237 if let Some(positive) = kwargs.get_item("positive")? {
238 self.py_config.positive = positive.extract()?;
239 }
240 if let Some(random_state) = kwargs.get_item("random_state")? {
241 self.py_config.random_state = random_state.extract()?;
242 }
243 if let Some(selection) = kwargs.get_item("selection")? {
244 let selection_str: String = selection.extract()?;
245 self.py_config.selection = selection_str;
246 }
247
248 self.fitted_model = None;
250
251 Ok(())
252 }
253
254 fn __repr__(&self) -> String {
256 format!(
257 "Lasso(alpha={}, fit_intercept={}, copy_X={}, max_iter={}, tol={}, warm_start={}, positive={}, random_state={:?}, selection='{}')",
258 self.py_config.alpha,
259 self.py_config.fit_intercept,
260 self.py_config.copy_x,
261 self.py_config.max_iter,
262 self.py_config.tol,
263 self.py_config.warm_start,
264 self.py_config.positive,
265 self.py_config.random_state,
266 self.py_config.selection
267 )
268 }
269}