sklears_python/lib.rs
1//! Python bindings for the sklears machine learning library
2//!
3//! This crate provides PyO3-based Python bindings for sklears, enabling
4//! seamless integration with the Python ecosystem while maintaining
5//! Rust's performance advantages.
6//!
7//! # Features
8//!
9//! - Drop-in replacement for scikit-learn's most common algorithms
10//! - Pure Rust implementation with ongoing performance optimization
11//! - Full NumPy array compatibility
12//! - Comprehensive error handling with Python exceptions
13//! - Memory-safe operations with automatic reference counting
14//!
15//! # Example
16//!
17//! ```python
18//! import sklears_python as skl
19//! import numpy as np
20//!
21//! # Create sample data
22//! X = np.random.randn(100, 4)
23//! y = np.random.randn(100)
24//!
25//! # Train a linear regression model
26//! model = skl.LinearRegression()
27//! model.fit(X, y)
28//! predictions = model.predict(X)
29//! ```
30
31#[allow(unused_imports)]
32use pyo3::prelude::*;
33
34// Import modules
35mod clustering;
36mod datasets;
37mod ensemble;
38mod linear;
39// mod metrics; // TODO: Needs refactoring to use sklears-metrics directly
40mod model_selection;
41mod naive_bayes;
42mod neural_network;
43// mod preprocessing; // Temporarily disabled to test ensemble
44mod tree;
45mod utils;
46
47// Re-export main classes
48pub use clustering::*;
49pub use ensemble::*;
50pub use linear::*;
51// pub use metrics::*; // TODO: Needs refactoring
52pub use model_selection::*;
53pub use naive_bayes::*;
54pub use neural_network::*;
55// pub use preprocessing::*; // Temporarily disabled to test ensemble
56pub use tree::*;
57pub use utils::*;
58
59/// Python module for sklears machine learning library
60#[pymodule]
61fn _sklears(m: &Bound<'_, PyModule>) -> PyResult<()> {
62 // Set module metadata
63 m.add("__version__", "0.1.0")?;
64 m.add(
65 "__doc__",
66 "High-performance machine learning library with scikit-learn compatibility",
67 )?;
68
69 // Linear models
70 m.add_class::<linear::PyLinearRegression>()?;
71 m.add_class::<linear::PyRidge>()?;
72 m.add_class::<linear::PyLasso>()?;
73 m.add_class::<linear::PyElasticNet>()?;
74 m.add_class::<linear::PyBayesianRidge>()?;
75 m.add_class::<linear::PyARDRegression>()?;
76 m.add_class::<linear::PyLogisticRegression>()?;
77
78 // Ensemble methods
79 m.add_class::<ensemble::PyGradientBoostingClassifier>()?;
80 m.add_class::<ensemble::PyGradientBoostingRegressor>()?;
81 m.add_class::<ensemble::PyAdaBoostClassifier>()?;
82 m.add_class::<ensemble::PyVotingClassifier>()?;
83 m.add_class::<ensemble::PyBaggingClassifier>()?;
84
85 // Neural networks
86 m.add_class::<neural_network::PyMLPClassifier>()?;
87 m.add_class::<neural_network::PyMLPRegressor>()?;
88
89 // Tree-based models - Temporarily disabled to test ensemble
90 // m.add_class::<tree::PyDecisionTreeClassifier>()?;
91 // m.add_class::<tree::PyDecisionTreeRegressor>()?;
92 // m.add_class::<tree::PyRandomForestClassifier>()?;
93 // m.add_class::<tree::PyRandomForestRegressor>()?;
94
95 // Naive Bayes
96 m.add_class::<naive_bayes::PyGaussianNB>()?;
97 m.add_class::<naive_bayes::PyMultinomialNB>()?;
98 m.add_class::<naive_bayes::PyBernoulliNB>()?;
99 m.add_class::<naive_bayes::PyComplementNB>()?;
100
101 // Clustering
102 m.add_class::<clustering::PyKMeans>()?;
103 m.add_class::<clustering::PyDBSCAN>()?;
104
105 // Preprocessing - Temporarily disabled to test ensemble
106 // m.add_class::<preprocessing::PyStandardScaler>()?;
107 // m.add_class::<preprocessing::PyMinMaxScaler>()?;
108 // m.add_class::<preprocessing::PyLabelEncoder>()?;
109
110 // TODO: Re-enable metrics after refactoring to use sklears-metrics directly
111 // Metrics - Regression
112 // m.add_function(wrap_pyfunction!(metrics::mean_squared_error, m)?)?;
113 // m.add_function(wrap_pyfunction!(metrics::mean_absolute_error, m)?)?;
114 // m.add_function(wrap_pyfunction!(metrics::r2_score, m)?)?;
115 // m.add_function(wrap_pyfunction!(metrics::mean_squared_log_error, m)?)?;
116 // m.add_function(wrap_pyfunction!(metrics::median_absolute_error, m)?)?;
117
118 // Metrics - Classification
119 // m.add_function(wrap_pyfunction!(metrics::accuracy_score, m)?)?;
120 // m.add_function(wrap_pyfunction!(metrics::precision_score, m)?)?;
121 // m.add_function(wrap_pyfunction!(metrics::recall_score, m)?)?;
122 // m.add_function(wrap_pyfunction!(metrics::f1_score, m)?)?;
123 // m.add_function(wrap_pyfunction!(metrics::confusion_matrix, m)?)?;
124 // m.add_function(wrap_pyfunction!(metrics::classification_report, m)?)?;
125
126 // Model selection
127 m.add_function(wrap_pyfunction!(model_selection::train_test_split, m)?)?;
128 m.add_class::<model_selection::PyKFold>()?;
129
130 // Dataset functions
131 datasets::register_dataset_functions(m)?;
132
133 // Utility functions
134 m.add_function(wrap_pyfunction!(utils::get_version, m)?)?;
135 m.add_function(wrap_pyfunction!(utils::get_build_info, m)?)?;
136
137 Ok(())
138}