sklears_python/lib.rs
1#![allow(missing_docs)]
2#![allow(clippy::too_many_arguments)]
3#![allow(clippy::type_complexity)]
4//! Python bindings for the sklears machine learning library
5//!
6//! This crate provides PyO3-based Python bindings for sklears, enabling
7//! seamless integration with the Python ecosystem while maintaining
8//! Rust's performance advantages.
9//!
10//! # Features
11//!
//! - scikit-learn-style estimator API (`fit` / `predict`) for a growing subset of scikit-learn's most common algorithms
13//! - Pure Rust implementation with ongoing performance optimization
14//! - Full NumPy array compatibility
15//! - Comprehensive error handling with Python exceptions
16//! - Memory-safe operations with automatic reference counting
17//!
18//! # Example
19//!
20//! ```python
21//! import sklears_python as skl
22//! import numpy as np
23//!
24//! # Create sample data
25//! X = np.random.randn(100, 4)
26//! y = np.random.randn(100)
27//!
28//! # Train a linear regression model
29//! model = skl.LinearRegression()
30//! model.fit(X, y)
31//! predictions = model.predict(X)
32//! ```
33
34#[allow(unused_imports)]
35use pyo3::prelude::*;
36
37// Import modules
38mod clustering;
39mod datasets;
40mod ensemble;
41mod linear;
42// mod metrics; // TODO: Needs refactoring to use sklears-metrics directly
43mod model_selection;
44mod naive_bayes;
45mod neural_network;
46// mod preprocessing; // Temporarily disabled to test ensemble
47mod tree;
48mod utils;
49
50// Re-export main classes
51pub use clustering::*;
52pub use ensemble::*;
53pub use linear::*;
54// pub use metrics::*; // TODO: Needs refactoring
55pub use model_selection::*;
56pub use naive_bayes::*;
57pub use neural_network::*;
58// pub use preprocessing::*; // Temporarily disabled to test ensemble
59pub use tree::*;
60pub use utils::*;
61
62/// Python module for sklears machine learning library
63#[pymodule]
64fn sklears_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
65 // Set module metadata
66 m.add("__version__", "0.1.0-beta.1")?;
67 m.add(
68 "__doc__",
69 "High-performance machine learning library with scikit-learn compatibility",
70 )?;
71
72 // Linear models
73 m.add_class::<linear::PyLinearRegression>()?;
74 m.add_class::<linear::PyRidge>()?;
75 m.add_class::<linear::PyLasso>()?;
76 m.add_class::<linear::PyElasticNet>()?;
77 m.add_class::<linear::PyBayesianRidge>()?;
78 m.add_class::<linear::PyARDRegression>()?;
79 m.add_class::<linear::PyLogisticRegression>()?;
80
81 // Ensemble methods
82 m.add_class::<ensemble::PyGradientBoostingClassifier>()?;
83 m.add_class::<ensemble::PyGradientBoostingRegressor>()?;
84 m.add_class::<ensemble::PyAdaBoostClassifier>()?;
85 m.add_class::<ensemble::PyVotingClassifier>()?;
86 m.add_class::<ensemble::PyBaggingClassifier>()?;
87
88 // Neural networks
89 m.add_class::<neural_network::PyMLPClassifier>()?;
90 m.add_class::<neural_network::PyMLPRegressor>()?;
91
92 // Tree-based models - Temporarily disabled to test ensemble
93 // m.add_class::<tree::PyDecisionTreeClassifier>()?;
94 // m.add_class::<tree::PyDecisionTreeRegressor>()?;
95 // m.add_class::<tree::PyRandomForestClassifier>()?;
96 // m.add_class::<tree::PyRandomForestRegressor>()?;
97
98 // Naive Bayes
99 m.add_class::<naive_bayes::PyGaussianNB>()?;
100 m.add_class::<naive_bayes::PyMultinomialNB>()?;
101 m.add_class::<naive_bayes::PyBernoulliNB>()?;
102 m.add_class::<naive_bayes::PyComplementNB>()?;
103
104 // Clustering
105 m.add_class::<clustering::PyKMeans>()?;
106 m.add_class::<clustering::PyDBSCAN>()?;
107
108 // Preprocessing - Temporarily disabled to test ensemble
109 // m.add_class::<preprocessing::PyStandardScaler>()?;
110 // m.add_class::<preprocessing::PyMinMaxScaler>()?;
111 // m.add_class::<preprocessing::PyLabelEncoder>()?;
112
113 // TODO: Re-enable metrics after refactoring to use sklears-metrics directly
114 // Metrics - Regression
115 // m.add_function(wrap_pyfunction!(metrics::mean_squared_error, m)?)?;
116 // m.add_function(wrap_pyfunction!(metrics::mean_absolute_error, m)?)?;
117 // m.add_function(wrap_pyfunction!(metrics::r2_score, m)?)?;
118 // m.add_function(wrap_pyfunction!(metrics::mean_squared_log_error, m)?)?;
119 // m.add_function(wrap_pyfunction!(metrics::median_absolute_error, m)?)?;
120
121 // Metrics - Classification
122 // m.add_function(wrap_pyfunction!(metrics::accuracy_score, m)?)?;
123 // m.add_function(wrap_pyfunction!(metrics::precision_score, m)?)?;
124 // m.add_function(wrap_pyfunction!(metrics::recall_score, m)?)?;
125 // m.add_function(wrap_pyfunction!(metrics::f1_score, m)?)?;
126 // m.add_function(wrap_pyfunction!(metrics::confusion_matrix, m)?)?;
127 // m.add_function(wrap_pyfunction!(metrics::classification_report, m)?)?;
128
129 // Model selection
130 m.add_function(wrap_pyfunction!(model_selection::train_test_split, m)?)?;
131 m.add_class::<model_selection::PyKFold>()?;
132
133 // Dataset functions
134 datasets::register_dataset_functions(m)?;
135
136 // Utility functions
137 m.add_function(wrap_pyfunction!(utils::get_version, m)?)?;
138 m.add_function(wrap_pyfunction!(utils::get_build_info, m)?)?;
139
140 Ok(())
141}