Skip to main content

sklears_python/
lib.rs

//! Python bindings for the sklears machine learning library.
//!
//! This crate provides PyO3-based Python bindings for sklears, enabling
//! seamless integration with the Python ecosystem while retaining
//! Rust's performance advantages.
//!
//! # Features
//!
//! - Drop-in replacements for scikit-learn's most common algorithms
//! - Pure Rust implementation with ongoing performance optimization
//! - Full NumPy array compatibility
//! - Comprehensive error handling that surfaces failures as Python exceptions
//! - Memory-safe operations with automatic reference counting
//!
//! # Examples
//!
//! ```python
//! import sklears_python as skl
//! import numpy as np
//!
//! # Create sample data
//! X = np.random.randn(100, 4)
//! y = np.random.randn(100)
//!
//! # Train a linear regression model
//! model = skl.LinearRegression()
//! model.fit(X, y)
//! predictions = model.predict(X)
//! ```
30
31#[allow(unused_imports)]
32use pyo3::prelude::*;
33
34// Import modules
35mod clustering;
36mod datasets;
37mod ensemble;
38mod linear;
39// mod metrics; // TODO: Needs refactoring to use sklears-metrics directly
40mod model_selection;
41mod naive_bayes;
42mod neural_network;
43// mod preprocessing; // Temporarily disabled to test ensemble
44mod tree;
45mod utils;
46
47// Re-export main classes
48pub use clustering::*;
49pub use ensemble::*;
50pub use linear::*;
51// pub use metrics::*; // TODO: Needs refactoring
52pub use model_selection::*;
53pub use naive_bayes::*;
54pub use neural_network::*;
55// pub use preprocessing::*; // Temporarily disabled to test ensemble
56pub use tree::*;
57pub use utils::*;
58
59/// Python module for sklears machine learning library
60#[pymodule]
61fn _sklears(m: &Bound<'_, PyModule>) -> PyResult<()> {
62    // Set module metadata
63    m.add("__version__", "0.1.0")?;
64    m.add(
65        "__doc__",
66        "High-performance machine learning library with scikit-learn compatibility",
67    )?;
68
69    // Linear models
70    m.add_class::<linear::PyLinearRegression>()?;
71    m.add_class::<linear::PyRidge>()?;
72    m.add_class::<linear::PyLasso>()?;
73    m.add_class::<linear::PyElasticNet>()?;
74    m.add_class::<linear::PyBayesianRidge>()?;
75    m.add_class::<linear::PyARDRegression>()?;
76    m.add_class::<linear::PyLogisticRegression>()?;
77
78    // Ensemble methods
79    m.add_class::<ensemble::PyGradientBoostingClassifier>()?;
80    m.add_class::<ensemble::PyGradientBoostingRegressor>()?;
81    m.add_class::<ensemble::PyAdaBoostClassifier>()?;
82    m.add_class::<ensemble::PyVotingClassifier>()?;
83    m.add_class::<ensemble::PyBaggingClassifier>()?;
84
85    // Neural networks
86    m.add_class::<neural_network::PyMLPClassifier>()?;
87    m.add_class::<neural_network::PyMLPRegressor>()?;
88
89    // Tree-based models - Temporarily disabled to test ensemble
90    // m.add_class::<tree::PyDecisionTreeClassifier>()?;
91    // m.add_class::<tree::PyDecisionTreeRegressor>()?;
92    // m.add_class::<tree::PyRandomForestClassifier>()?;
93    // m.add_class::<tree::PyRandomForestRegressor>()?;
94
95    // Naive Bayes
96    m.add_class::<naive_bayes::PyGaussianNB>()?;
97    m.add_class::<naive_bayes::PyMultinomialNB>()?;
98    m.add_class::<naive_bayes::PyBernoulliNB>()?;
99    m.add_class::<naive_bayes::PyComplementNB>()?;
100
101    // Clustering
102    m.add_class::<clustering::PyKMeans>()?;
103    m.add_class::<clustering::PyDBSCAN>()?;
104
105    // Preprocessing - Temporarily disabled to test ensemble
106    // m.add_class::<preprocessing::PyStandardScaler>()?;
107    // m.add_class::<preprocessing::PyMinMaxScaler>()?;
108    // m.add_class::<preprocessing::PyLabelEncoder>()?;
109
110    // TODO: Re-enable metrics after refactoring to use sklears-metrics directly
111    // Metrics - Regression
112    // m.add_function(wrap_pyfunction!(metrics::mean_squared_error, m)?)?;
113    // m.add_function(wrap_pyfunction!(metrics::mean_absolute_error, m)?)?;
114    // m.add_function(wrap_pyfunction!(metrics::r2_score, m)?)?;
115    // m.add_function(wrap_pyfunction!(metrics::mean_squared_log_error, m)?)?;
116    // m.add_function(wrap_pyfunction!(metrics::median_absolute_error, m)?)?;
117
118    // Metrics - Classification
119    // m.add_function(wrap_pyfunction!(metrics::accuracy_score, m)?)?;
120    // m.add_function(wrap_pyfunction!(metrics::precision_score, m)?)?;
121    // m.add_function(wrap_pyfunction!(metrics::recall_score, m)?)?;
122    // m.add_function(wrap_pyfunction!(metrics::f1_score, m)?)?;
123    // m.add_function(wrap_pyfunction!(metrics::confusion_matrix, m)?)?;
124    // m.add_function(wrap_pyfunction!(metrics::classification_report, m)?)?;
125
126    // Model selection
127    m.add_function(wrap_pyfunction!(model_selection::train_test_split, m)?)?;
128    m.add_class::<model_selection::PyKFold>()?;
129
130    // Dataset functions
131    datasets::register_dataset_functions(m)?;
132
133    // Utility functions
134    m.add_function(wrap_pyfunction!(utils::get_version, m)?)?;
135    m.add_function(wrap_pyfunction!(utils::get_build_info, m)?)?;
136
137    Ok(())
138}