Skip to main content

sklears_python/
lib.rs

1#![allow(missing_docs)]
2#![allow(clippy::too_many_arguments)]
3#![allow(clippy::type_complexity)]
4//! Python bindings for the sklears machine learning library
5//!
6//! This crate provides PyO3-based Python bindings for sklears, enabling
7//! seamless integration with the Python ecosystem while maintaining
8//! Rust's performance advantages.
9//!
10//! # Features
11//!
12//! - Drop-in replacement for scikit-learn's most common algorithms
13//! - Pure Rust implementation with ongoing performance optimization
14//! - Full NumPy array compatibility
15//! - Comprehensive error handling with Python exceptions
16//! - Memory-safe operations with automatic reference counting
17//!
18//! # Example
19//!
20//! ```python
21//! import sklears_python as skl
22//! import numpy as np
23//!
24//! # Create sample data
25//! X = np.random.randn(100, 4)
26//! y = np.random.randn(100)
27//!
28//! # Train a linear regression model
29//! model = skl.LinearRegression()
30//! model.fit(X, y)
31//! predictions = model.predict(X)
32//! ```
33
34#[allow(unused_imports)]
35use pyo3::prelude::*;
36
37// Import modules
38mod clustering;
39mod datasets;
40mod ensemble;
41mod linear;
42// mod metrics; // TODO: Needs refactoring to use sklears-metrics directly
43mod model_selection;
44mod naive_bayes;
45mod neural_network;
46// mod preprocessing; // Temporarily disabled to test ensemble
47mod tree;
48mod utils;
49
50// Re-export main classes
51pub use clustering::*;
52pub use ensemble::*;
53pub use linear::*;
54// pub use metrics::*; // TODO: Needs refactoring
55pub use model_selection::*;
56pub use naive_bayes::*;
57pub use neural_network::*;
58// pub use preprocessing::*; // Temporarily disabled to test ensemble
59pub use tree::*;
60pub use utils::*;
61
62/// Python module for sklears machine learning library
63#[pymodule]
64fn sklears_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
65    // Set module metadata
66    m.add("__version__", "0.1.0-beta.1")?;
67    m.add(
68        "__doc__",
69        "High-performance machine learning library with scikit-learn compatibility",
70    )?;
71
72    // Linear models
73    m.add_class::<linear::PyLinearRegression>()?;
74    m.add_class::<linear::PyRidge>()?;
75    m.add_class::<linear::PyLasso>()?;
76    m.add_class::<linear::PyElasticNet>()?;
77    m.add_class::<linear::PyBayesianRidge>()?;
78    m.add_class::<linear::PyARDRegression>()?;
79    m.add_class::<linear::PyLogisticRegression>()?;
80
81    // Ensemble methods
82    m.add_class::<ensemble::PyGradientBoostingClassifier>()?;
83    m.add_class::<ensemble::PyGradientBoostingRegressor>()?;
84    m.add_class::<ensemble::PyAdaBoostClassifier>()?;
85    m.add_class::<ensemble::PyVotingClassifier>()?;
86    m.add_class::<ensemble::PyBaggingClassifier>()?;
87
88    // Neural networks
89    m.add_class::<neural_network::PyMLPClassifier>()?;
90    m.add_class::<neural_network::PyMLPRegressor>()?;
91
92    // Tree-based models - Temporarily disabled to test ensemble
93    // m.add_class::<tree::PyDecisionTreeClassifier>()?;
94    // m.add_class::<tree::PyDecisionTreeRegressor>()?;
95    // m.add_class::<tree::PyRandomForestClassifier>()?;
96    // m.add_class::<tree::PyRandomForestRegressor>()?;
97
98    // Naive Bayes
99    m.add_class::<naive_bayes::PyGaussianNB>()?;
100    m.add_class::<naive_bayes::PyMultinomialNB>()?;
101    m.add_class::<naive_bayes::PyBernoulliNB>()?;
102    m.add_class::<naive_bayes::PyComplementNB>()?;
103
104    // Clustering
105    m.add_class::<clustering::PyKMeans>()?;
106    m.add_class::<clustering::PyDBSCAN>()?;
107
108    // Preprocessing - Temporarily disabled to test ensemble
109    // m.add_class::<preprocessing::PyStandardScaler>()?;
110    // m.add_class::<preprocessing::PyMinMaxScaler>()?;
111    // m.add_class::<preprocessing::PyLabelEncoder>()?;
112
113    // TODO: Re-enable metrics after refactoring to use sklears-metrics directly
114    // Metrics - Regression
115    // m.add_function(wrap_pyfunction!(metrics::mean_squared_error, m)?)?;
116    // m.add_function(wrap_pyfunction!(metrics::mean_absolute_error, m)?)?;
117    // m.add_function(wrap_pyfunction!(metrics::r2_score, m)?)?;
118    // m.add_function(wrap_pyfunction!(metrics::mean_squared_log_error, m)?)?;
119    // m.add_function(wrap_pyfunction!(metrics::median_absolute_error, m)?)?;
120
121    // Metrics - Classification
122    // m.add_function(wrap_pyfunction!(metrics::accuracy_score, m)?)?;
123    // m.add_function(wrap_pyfunction!(metrics::precision_score, m)?)?;
124    // m.add_function(wrap_pyfunction!(metrics::recall_score, m)?)?;
125    // m.add_function(wrap_pyfunction!(metrics::f1_score, m)?)?;
126    // m.add_function(wrap_pyfunction!(metrics::confusion_matrix, m)?)?;
127    // m.add_function(wrap_pyfunction!(metrics::classification_report, m)?)?;
128
129    // Model selection
130    m.add_function(wrap_pyfunction!(model_selection::train_test_split, m)?)?;
131    m.add_class::<model_selection::PyKFold>()?;
132
133    // Dataset functions
134    datasets::register_dataset_functions(m)?;
135
136    // Utility functions
137    m.add_function(wrap_pyfunction!(utils::get_version, m)?)?;
138    m.add_function(wrap_pyfunction!(utils::get_build_info, m)?)?;
139
140    Ok(())
141}