// sklears_python/neural_network.rs
//! Python bindings for neural network models
//!
//! This module provides PyO3-based Python bindings for sklears neural network algorithms,
//! including Multi-Layer Perceptron (MLP) classifiers and regressors.

use numpy::{PyArray1, PyArray2};
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use sklears_core::traits::{Fit, Predict};
use sklears_neural::solvers::LearningRateSchedule;
use sklears_neural::{Activation, MLPClassifier, MLPRegressor, Solver};

use crate::linear::common::{core_array1_to_py, core_array2_to_py};
use crate::utils::{numpy_to_ndarray1, numpy_to_ndarray2};
14
/// Python wrapper for MLP Classifier
///
/// Typestate handoff: `inner` holds the unfitted configuration and is
/// consumed (taken) by `fit`; on success the trained model lands in
/// `trained`, which the prediction/accessor methods read.
#[pyclass(name = "MLPClassifier")]
pub struct PyMLPClassifier {
    // Unfitted model configuration; `None` once `fit` has consumed it.
    inner: Option<MLPClassifier<sklears_core::traits::Untrained>>,
    // Trained model; populated only after a successful `fit`.
    trained: Option<MLPClassifier<sklears_neural::TrainedMLPClassifier>>,
}
21
22#[pymethods]
23impl PyMLPClassifier {
24    #[new]
25    #[allow(clippy::too_many_arguments)]
26    #[pyo3(signature = (
27        hidden_layer_sizes=None,
28        activation="relu",
29        solver="adam",
30        alpha=0.0001,
31        batch_size=None,
32        learning_rate="constant",
33        learning_rate_init=0.001,
34        power_t=0.5,
35        max_iter=200,
36        shuffle=true,
37        random_state=None,
38        tol=1e-4,
39        verbose=false,
40        warm_start=false,
41        momentum=0.9,
42        nesterovs_momentum=true,
43        early_stopping=false,
44        validation_fraction=0.1,
45        beta_1=0.9,
46        beta_2=0.999,
47        epsilon=1e-8,
48        n_iter_no_change=10,
49        max_fun=15000
50    ))]
51    fn new(
52        hidden_layer_sizes: Option<Vec<usize>>,
53        activation: &str,
54        solver: &str,
55        alpha: f64,
56        batch_size: Option<usize>,
57        learning_rate: &str,
58        learning_rate_init: f64,
59        power_t: f64,
60        max_iter: usize,
61        shuffle: bool,
62        random_state: Option<u64>,
63        tol: f64,
64        verbose: bool,
65        warm_start: bool,
66        momentum: f64,
67        nesterovs_momentum: bool,
68        early_stopping: bool,
69        validation_fraction: f64,
70        beta_1: f64,
71        beta_2: f64,
72        epsilon: f64,
73        n_iter_no_change: usize,
74        max_fun: usize,
75    ) -> PyResult<Self> {
76        let activation = match activation {
77            "identity" => Activation::Identity,
78            "logistic" => Activation::Logistic,
79            "tanh" => Activation::Tanh,
80            "relu" => Activation::Relu,
81            "elu" => Activation::Elu,
82            "swish" => Activation::Swish,
83            "gelu" => Activation::Gelu,
84            "mish" => Activation::Mish,
85            "leaky_relu" => Activation::LeakyRelu,
86            "prelu" => Activation::PRelu,
87            _ => {
88                return Err(PyValueError::new_err(format!(
89                    "Unknown activation: {}",
90                    activation
91                )))
92            }
93        };
94
95        let solver = match solver {
96            "lbfgs" => Solver::Lbfgs,
97            "sgd" => Solver::Sgd,
98            "adam" => Solver::Adam,
99            _ => return Err(PyValueError::new_err(format!("Unknown solver: {}", solver))),
100        };
101
102        let learning_rate_schedule = match learning_rate {
103            "constant" => LearningRateSchedule::Constant,
104            "invscaling" => LearningRateSchedule::InvScaling,
105            "adaptive" => LearningRateSchedule::Adaptive,
106            _ => {
107                return Err(PyValueError::new_err(format!(
108                    "Unknown learning rate schedule: {}",
109                    learning_rate
110                )))
111            }
112        };
113
114        let hidden_sizes = hidden_layer_sizes.unwrap_or_else(|| vec![100]);
115
116        let mut mlp = MLPClassifier::new();
117        mlp.hidden_layer_sizes = hidden_sizes;
118        mlp.activation = activation;
119        mlp.solver = solver;
120        mlp.alpha = alpha;
121        mlp.batch_size = batch_size;
122        mlp.learning_rate = learning_rate_schedule;
123        mlp.learning_rate_init = learning_rate_init;
124        mlp.power_t = power_t;
125        mlp.max_iter = max_iter;
126        mlp.shuffle = shuffle;
127        mlp.random_state = random_state;
128        mlp.tol = tol;
129        mlp.verbose = verbose;
130        mlp.warm_start = warm_start;
131        mlp.momentum = momentum;
132        mlp.nesterovs_momentum = nesterovs_momentum;
133        mlp.early_stopping = early_stopping;
134        mlp.validation_fraction = validation_fraction;
135        mlp.beta_1 = beta_1;
136        mlp.beta_2 = beta_2;
137        mlp.epsilon = epsilon;
138        mlp.n_iter_no_change = n_iter_no_change;
139        mlp.max_fun = max_fun;
140
141        Ok(Self {
142            inner: Some(mlp),
143            trained: None,
144        })
145    }
146
147    /// Fit the MLP classifier
148    fn fit(&mut self, x: &Bound<'_, PyArray2<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
149        let x_array = numpy_to_ndarray2(x)?;
150        let y_array = numpy_to_ndarray1(y)?;
151
152        // Convert y to integer vector for classification
153        let y_int: Vec<usize> = y_array.iter().map(|&val| val as usize).collect();
154
155        let model = self.inner.take().ok_or_else(|| {
156            PyRuntimeError::new_err("Model has already been fitted or was not initialized")
157        })?;
158
159        match model.fit(&x_array, &y_int) {
160            Ok(trained_model) => {
161                self.trained = Some(trained_model);
162                Ok(())
163            }
164            Err(e) => Err(PyRuntimeError::new_err(format!(
165                "Failed to fit model: {}",
166                e
167            ))),
168        }
169    }
170
171    /// Make predictions using the fitted model
172    fn predict<'py>(
173        &self,
174        py: Python<'py>,
175        x: &Bound<'py, PyArray2<f64>>,
176    ) -> PyResult<Py<PyArray1<f64>>> {
177        let trained_model = self.trained.as_ref().ok_or_else(|| {
178            PyRuntimeError::new_err("Model must be fitted before making predictions")
179        })?;
180
181        let x_array = numpy_to_ndarray2(x)?;
182
183        match trained_model.predict(&x_array) {
184            Ok(predictions) => {
185                let predictions_f64: Vec<f64> = predictions.iter().map(|&x| x as f64).collect();
186                Ok(PyArray1::from_vec(py, predictions_f64).unbind())
187            }
188            Err(e) => Err(PyRuntimeError::new_err(format!("Prediction failed: {}", e))),
189        }
190    }
191
192    /// Predict class probabilities
193    fn predict_proba<'py>(
194        &self,
195        py: Python<'py>,
196        x: &Bound<'py, PyArray2<f64>>,
197    ) -> PyResult<Py<PyArray2<f64>>> {
198        let trained_model = self.trained.as_ref().ok_or_else(|| {
199            PyRuntimeError::new_err("Model must be fitted before making predictions")
200        })?;
201
202        let x_array = numpy_to_ndarray2(x)?;
203
204        match trained_model.predict_proba(&x_array) {
205            Ok(probabilities) => Ok(core_array2_to_py(py, &probabilities)?),
206            Err(e) => Err(PyRuntimeError::new_err(format!(
207                "Probability prediction failed: {}",
208                e
209            ))),
210        }
211    }
212
213    /// Get the loss after training
214    fn loss_(&self) -> PyResult<f64> {
215        let trained_model = self
216            .trained
217            .as_ref()
218            .ok_or_else(|| PyRuntimeError::new_err("Model must be fitted before accessing loss"))?;
219
220        Ok(trained_model.loss())
221    }
222
223    /// Get number of iterations
224    fn n_iter_(&self) -> PyResult<usize> {
225        let trained_model = self.trained.as_ref().ok_or_else(|| {
226            PyRuntimeError::new_err("Model must be fitted before accessing n_iter")
227        })?;
228
229        Ok(trained_model.n_iter())
230    }
231
232    fn __repr__(&self) -> String {
233        if self.trained.is_some() {
234            "MLPClassifier(fitted=True)".to_string()
235        } else {
236            "MLPClassifier(fitted=False)".to_string()
237        }
238    }
239}
240
/// Python wrapper for MLP Regressor
///
/// Typestate handoff: `inner` holds the unfitted configuration and is
/// consumed (taken) by `fit`; on success the trained model lands in
/// `trained`, which the prediction/accessor methods read.
#[pyclass(name = "MLPRegressor")]
pub struct PyMLPRegressor {
    // Unfitted model configuration; `None` once `fit` has consumed it.
    inner: Option<MLPRegressor<sklears_core::traits::Untrained>>,
    // Trained model; populated only after a successful `fit`.
    trained: Option<MLPRegressor<sklears_neural::TrainedMLPRegressor>>,
}
247
248#[pymethods]
249impl PyMLPRegressor {
250    #[new]
251    #[allow(clippy::too_many_arguments)]
252    #[pyo3(signature = (
253        hidden_layer_sizes=None,
254        activation="relu",
255        solver="adam",
256        alpha=0.0001,
257        batch_size=None,
258        learning_rate="constant",
259        learning_rate_init=0.001,
260        power_t=0.5,
261        max_iter=200,
262        shuffle=true,
263        random_state=None,
264        tol=1e-4,
265        verbose=false,
266        warm_start=false,
267        momentum=0.9,
268        nesterovs_momentum=true,
269        early_stopping=false,
270        validation_fraction=0.1,
271        beta_1=0.9,
272        beta_2=0.999,
273        epsilon=1e-8,
274        n_iter_no_change=10,
275        max_fun=15000
276    ))]
277    fn new(
278        hidden_layer_sizes: Option<Vec<usize>>,
279        activation: &str,
280        solver: &str,
281        alpha: f64,
282        batch_size: Option<usize>,
283        learning_rate: &str,
284        learning_rate_init: f64,
285        power_t: f64,
286        max_iter: usize,
287        shuffle: bool,
288        random_state: Option<u64>,
289        tol: f64,
290        verbose: bool,
291        warm_start: bool,
292        momentum: f64,
293        nesterovs_momentum: bool,
294        early_stopping: bool,
295        validation_fraction: f64,
296        beta_1: f64,
297        beta_2: f64,
298        epsilon: f64,
299        n_iter_no_change: usize,
300        max_fun: usize,
301    ) -> PyResult<Self> {
302        let activation = match activation {
303            "identity" => Activation::Identity,
304            "logistic" => Activation::Logistic,
305            "tanh" => Activation::Tanh,
306            "relu" => Activation::Relu,
307            "elu" => Activation::Elu,
308            "swish" => Activation::Swish,
309            "gelu" => Activation::Gelu,
310            "mish" => Activation::Mish,
311            "leaky_relu" => Activation::LeakyRelu,
312            "prelu" => Activation::PRelu,
313            _ => {
314                return Err(PyValueError::new_err(format!(
315                    "Unknown activation: {}",
316                    activation
317                )))
318            }
319        };
320
321        let solver = match solver {
322            "lbfgs" => Solver::Lbfgs,
323            "sgd" => Solver::Sgd,
324            "adam" => Solver::Adam,
325            _ => return Err(PyValueError::new_err(format!("Unknown solver: {}", solver))),
326        };
327
328        let learning_rate_schedule = match learning_rate {
329            "constant" => LearningRateSchedule::Constant,
330            "invscaling" => LearningRateSchedule::InvScaling,
331            "adaptive" => LearningRateSchedule::Adaptive,
332            _ => {
333                return Err(PyValueError::new_err(format!(
334                    "Unknown learning rate schedule: {}",
335                    learning_rate
336                )))
337            }
338        };
339
340        let hidden_sizes = hidden_layer_sizes.unwrap_or_else(|| vec![100]);
341
342        let mut mlp = MLPRegressor::new();
343        mlp.hidden_layer_sizes = hidden_sizes;
344        mlp.activation = activation;
345        mlp.solver = solver;
346        mlp.alpha = alpha;
347        mlp.batch_size = batch_size;
348        mlp.learning_rate = learning_rate_schedule;
349        mlp.learning_rate_init = learning_rate_init;
350        mlp.power_t = power_t;
351        mlp.max_iter = max_iter;
352        mlp.shuffle = shuffle;
353        mlp.random_state = random_state;
354        mlp.tol = tol;
355        mlp.verbose = verbose;
356        mlp.warm_start = warm_start;
357        mlp.momentum = momentum;
358        mlp.nesterovs_momentum = nesterovs_momentum;
359        mlp.early_stopping = early_stopping;
360        mlp.validation_fraction = validation_fraction;
361        mlp.beta_1 = beta_1;
362        mlp.beta_2 = beta_2;
363        mlp.epsilon = epsilon;
364        mlp.n_iter_no_change = n_iter_no_change;
365        mlp.max_fun = max_fun;
366
367        Ok(Self {
368            inner: Some(mlp),
369            trained: None,
370        })
371    }
372
373    /// Fit the MLP regressor
374    fn fit(&mut self, x: &Bound<'_, PyArray2<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
375        let x_array = numpy_to_ndarray2(x)?;
376        let y_array_1d = numpy_to_ndarray1(y)?;
377
378        // Convert y from 1D to 2D array (n_samples, 1)
379        let y_array = y_array_1d.insert_axis(scirs2_core::ndarray::Axis(1));
380
381        let model = self.inner.take().ok_or_else(|| {
382            PyRuntimeError::new_err("Model has already been fitted or was not initialized")
383        })?;
384
385        match model.fit(&x_array, &y_array) {
386            Ok(trained_model) => {
387                self.trained = Some(trained_model);
388                Ok(())
389            }
390            Err(e) => Err(PyRuntimeError::new_err(format!(
391                "Failed to fit model: {}",
392                e
393            ))),
394        }
395    }
396
397    /// Make predictions using the fitted model
398    fn predict<'py>(
399        &self,
400        py: Python<'py>,
401        x: &Bound<'py, PyArray2<f64>>,
402    ) -> PyResult<Py<PyArray1<f64>>> {
403        let trained_model = self.trained.as_ref().ok_or_else(|| {
404            PyRuntimeError::new_err("Model must be fitted before making predictions")
405        })?;
406
407        let x_array = numpy_to_ndarray2(x)?;
408
409        match trained_model.predict(&x_array) {
410            Ok(predictions_2d) => {
411                // Convert from Array2 (n_samples, 1) to Array1 (n_samples,)
412                let predictions_1d = predictions_2d
413                    .index_axis(scirs2_core::ndarray::Axis(1), 0)
414                    .to_owned();
415                Ok(core_array1_to_py(py, &predictions_1d))
416            }
417            Err(e) => Err(PyRuntimeError::new_err(format!("Prediction failed: {}", e))),
418        }
419    }
420
421    /// Get the loss after training
422    fn loss_(&self) -> PyResult<f64> {
423        let trained_model = self
424            .trained
425            .as_ref()
426            .ok_or_else(|| PyRuntimeError::new_err("Model must be fitted before accessing loss"))?;
427
428        Ok(trained_model.loss())
429    }
430
431    /// Get number of iterations
432    fn n_iter_(&self) -> PyResult<usize> {
433        let trained_model = self.trained.as_ref().ok_or_else(|| {
434            PyRuntimeError::new_err("Model must be fitted before accessing n_iter")
435        })?;
436
437        Ok(trained_model.n_iter())
438    }
439
440    fn __repr__(&self) -> String {
441        if self.trained.is_some() {
442            "MLPRegressor(fitted=True)".to_string()
443        } else {
444            "MLPRegressor(fitted=False)".to_string()
445        }
446    }
447}