1use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyAny, PyDict};
9
10use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13#[allow(unused_imports)]
15use scirs2_core::ndarray::{Array1, Array2};
16
17use scirs2_datasets::{
19 generators::manifold::{make_s_curve, make_swiss_roll},
21 generators::{
23 make_blobs, make_circles, make_classification, make_moons, make_regression, make_spirals,
24 },
25 toy::{load_boston, load_breast_cancer, load_diabetes, load_digits, load_iris},
27 utils::{k_fold_split, min_max_scale, normalize, train_test_split},
29 Dataset,
31};
32
33fn dataset_to_pydict(py: Python, dataset: Dataset) -> PyResult<Py<PyAny>> {
39 let dict = PyDict::new(py);
40 let n_samples = dataset.data.nrows();
41 let n_features = dataset.data.ncols();
42 dict.set_item("data", dataset.data.into_pyarray(py).unbind())?;
43
44 if let Some(target) = dataset.target {
45 dict.set_item("target", target.into_pyarray(py).unbind())?;
46 }
47
48 if let Some(featurenames) = dataset.featurenames {
49 dict.set_item("feature_names", featurenames)?;
50 }
51
52 if let Some(targetnames) = dataset.targetnames {
53 dict.set_item("target_names", targetnames)?;
54 }
55
56 if let Some(description) = dataset.description {
57 dict.set_item("description", description)?;
58 }
59
60 dict.set_item("n_samples", n_samples)?;
61 dict.set_item("n_features", n_features)?;
62
63 Ok(dict.into())
64}
65
66#[pyfunction]
72fn load_iris_py(py: Python) -> PyResult<Py<PyAny>> {
73 let dataset =
74 load_iris().map_err(|e| PyRuntimeError::new_err(format!("Failed to load iris: {}", e)))?;
75 dataset_to_pydict(py, dataset)
76}
77
78#[pyfunction]
80fn load_boston_py(py: Python) -> PyResult<Py<PyAny>> {
81 let dataset = load_boston()
82 .map_err(|e| PyRuntimeError::new_err(format!("Failed to load boston: {}", e)))?;
83 dataset_to_pydict(py, dataset)
84}
85
86#[pyfunction]
88fn load_diabetes_py(py: Python) -> PyResult<Py<PyAny>> {
89 let dataset = load_diabetes()
90 .map_err(|e| PyRuntimeError::new_err(format!("Failed to load diabetes: {}", e)))?;
91 dataset_to_pydict(py, dataset)
92}
93
94#[pyfunction]
96fn load_breast_cancer_py(py: Python) -> PyResult<Py<PyAny>> {
97 let dataset = load_breast_cancer()
98 .map_err(|e| PyRuntimeError::new_err(format!("Failed to load breast cancer: {}", e)))?;
99 dataset_to_pydict(py, dataset)
100}
101
102#[pyfunction]
104fn load_digits_py(py: Python) -> PyResult<Py<PyAny>> {
105 let dataset = load_digits()
106 .map_err(|e| PyRuntimeError::new_err(format!("Failed to load digits: {}", e)))?;
107 dataset_to_pydict(py, dataset)
108}
109
110#[pyfunction]
116#[pyo3(signature = (n_samples=100, n_features=20, n_informative=2, n_redundant=2, n_clusters_per_class=2, random_seed=None))]
117fn make_classification_py(
118 py: Python,
119 n_samples: usize,
120 n_features: usize,
121 n_informative: usize,
122 n_redundant: usize,
123 n_clusters_per_class: usize,
124 random_seed: Option<u64>,
125) -> PyResult<Py<PyAny>> {
126 let dataset = make_classification(
127 n_samples,
128 n_features,
129 n_informative,
130 n_redundant,
131 n_clusters_per_class,
132 random_seed,
133 )
134 .map_err(|e| PyRuntimeError::new_err(format!("Failed to make classification: {}", e)))?;
135 dataset_to_pydict(py, dataset)
136}
137
138#[pyfunction]
140#[pyo3(signature = (n_samples=100, n_features=10, n_informative=5, noise=0.1, random_seed=None))]
141fn make_regression_py(
142 py: Python,
143 n_samples: usize,
144 n_features: usize,
145 n_informative: usize,
146 noise: f64,
147 random_seed: Option<u64>,
148) -> PyResult<Py<PyAny>> {
149 let dataset = make_regression(n_samples, n_features, n_informative, noise, random_seed)
150 .map_err(|e| PyRuntimeError::new_err(format!("Failed to make regression: {}", e)))?;
151 dataset_to_pydict(py, dataset)
152}
153
154#[pyfunction]
156#[pyo3(signature = (n_samples=100, n_features=2, n_clusters=3, std_dev=1.0, random_seed=None))]
157fn make_blobs_py(
158 py: Python,
159 n_samples: usize,
160 n_features: usize,
161 n_clusters: usize,
162 std_dev: f64,
163 random_seed: Option<u64>,
164) -> PyResult<Py<PyAny>> {
165 let dataset = make_blobs(n_samples, n_features, n_clusters, std_dev, random_seed)
166 .map_err(|e| PyRuntimeError::new_err(format!("Failed to make blobs: {}", e)))?;
167 dataset_to_pydict(py, dataset)
168}
169
170#[pyfunction]
172#[pyo3(signature = (n_samples=100, noise=0.1, random_seed=None))]
173fn make_moons_py(
174 py: Python,
175 n_samples: usize,
176 noise: f64,
177 random_seed: Option<u64>,
178) -> PyResult<Py<PyAny>> {
179 let dataset = make_moons(n_samples, noise, random_seed)
180 .map_err(|e| PyRuntimeError::new_err(format!("Failed to make moons: {}", e)))?;
181 dataset_to_pydict(py, dataset)
182}
183
184#[pyfunction]
186#[pyo3(signature = (n_samples=100, factor=0.8, noise=0.1, random_seed=None))]
187fn make_circles_py(
188 py: Python,
189 n_samples: usize,
190 factor: f64,
191 noise: f64,
192 random_seed: Option<u64>,
193) -> PyResult<Py<PyAny>> {
194 let dataset = make_circles(n_samples, factor, noise, random_seed)
195 .map_err(|e| PyRuntimeError::new_err(format!("Failed to make circles: {}", e)))?;
196 dataset_to_pydict(py, dataset)
197}
198
199#[pyfunction]
201#[pyo3(signature = (n_samples=100, n_spirals=2, noise=0.1, random_seed=None))]
202fn make_spirals_py(
203 py: Python,
204 n_samples: usize,
205 n_spirals: usize,
206 noise: f64,
207 random_seed: Option<u64>,
208) -> PyResult<Py<PyAny>> {
209 let dataset = make_spirals(n_samples, n_spirals, noise, random_seed)
210 .map_err(|e| PyRuntimeError::new_err(format!("Failed to make spirals: {}", e)))?;
211 dataset_to_pydict(py, dataset)
212}
213
214#[pyfunction]
220#[pyo3(signature = (n_samples=1000, noise=0.0, random_seed=None))]
221fn make_swiss_roll_py(
222 py: Python,
223 n_samples: usize,
224 noise: f64,
225 random_seed: Option<u64>,
226) -> PyResult<Py<PyAny>> {
227 let dataset = make_swiss_roll(n_samples, noise, random_seed)
228 .map_err(|e| PyRuntimeError::new_err(format!("Failed to make swiss roll: {}", e)))?;
229 dataset_to_pydict(py, dataset)
230}
231
232#[pyfunction]
234#[pyo3(signature = (n_samples=1000, noise=0.0, random_seed=None))]
235fn make_s_curve_py(
236 py: Python,
237 n_samples: usize,
238 noise: f64,
239 random_seed: Option<u64>,
240) -> PyResult<Py<PyAny>> {
241 let dataset = make_s_curve(n_samples, noise, random_seed)
242 .map_err(|e| PyRuntimeError::new_err(format!("Failed to make s-curve: {}", e)))?;
243 dataset_to_pydict(py, dataset)
244}
245
246#[pyfunction]
252#[pyo3(signature = (x, y, test_size=0.25, random_seed=None))]
253fn train_test_split_py(
254 py: Python,
255 x: &Bound<'_, PyArray2<f64>>,
256 y: &Bound<'_, PyArray1<f64>>,
257 test_size: f64,
258 random_seed: Option<u64>,
259) -> PyResult<Py<PyAny>> {
260 let x_binding = x.readonly();
261 let y_binding = y.readonly();
262 let x_data = x_binding.as_array().to_owned();
263 let y_data = y_binding.as_array().to_owned();
264
265 let dataset = Dataset {
266 data: x_data,
267 target: Some(y_data),
268 featurenames: None,
269 targetnames: None,
270 feature_descriptions: None,
271 description: None,
272 metadata: std::collections::HashMap::new(),
273 };
274
275 let (train, test) = train_test_split(&dataset, test_size, random_seed)
276 .map_err(|e| PyRuntimeError::new_err(format!("Train-test split failed: {}", e)))?;
277
278 let dict = PyDict::new(py);
279 dict.set_item("x_train", train.data.into_pyarray(py).unbind())?;
280 dict.set_item("x_test", test.data.into_pyarray(py).unbind())?;
281
282 if let Some(y_train) = train.target {
283 dict.set_item("y_train", y_train.into_pyarray(py).unbind())?;
284 }
285 if let Some(y_test) = test.target {
286 dict.set_item("y_test", y_test.into_pyarray(py).unbind())?;
287 }
288
289 Ok(dict.into())
290}
291
292#[pyfunction]
294#[pyo3(signature = (n_samples, n_folds=5, shuffle=true, random_seed=None))]
295fn k_fold_split_py(
296 py: Python,
297 n_samples: usize,
298 n_folds: usize,
299 shuffle: bool,
300 random_seed: Option<u64>,
301) -> PyResult<Py<PyAny>> {
302 let folds = k_fold_split(n_samples, n_folds, shuffle, random_seed)
303 .map_err(|e| PyRuntimeError::new_err(format!("K-fold split failed: {}", e)))?;
304
305 let dict = PyDict::new(py);
306 dict.set_item("n_folds", n_folds)?;
307
308 dict.set_item("folds", folds)?;
310
311 Ok(dict.into())
312}
313
314#[pyfunction]
316#[pyo3(signature = (data, min_val=0.0, max_val=1.0))]
317fn min_max_scale_py(
318 py: Python,
319 data: &Bound<'_, PyArray2<f64>>,
320 min_val: f64,
321 max_val: f64,
322) -> PyResult<Py<PyArray2<f64>>> {
323 let binding = data.readonly();
324 let mut arr = binding.as_array().to_owned();
325
326 min_max_scale(&mut arr, (min_val, max_val));
327
328 Ok(arr.into_pyarray(py).unbind())
329}
330
331#[pyfunction]
333fn normalize_py(py: Python, data: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
334 let binding = data.readonly();
335 let mut arr = binding.as_array().to_owned();
336
337 normalize(&mut arr);
338
339 Ok(arr.into_pyarray(py).unbind())
340}
341
342pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
344 m.add_function(wrap_pyfunction!(load_iris_py, m)?)?;
346 m.add_function(wrap_pyfunction!(load_boston_py, m)?)?;
347 m.add_function(wrap_pyfunction!(load_diabetes_py, m)?)?;
348 m.add_function(wrap_pyfunction!(load_breast_cancer_py, m)?)?;
349 m.add_function(wrap_pyfunction!(load_digits_py, m)?)?;
350
351 m.add_function(wrap_pyfunction!(make_classification_py, m)?)?;
353 m.add_function(wrap_pyfunction!(make_regression_py, m)?)?;
354 m.add_function(wrap_pyfunction!(make_blobs_py, m)?)?;
355 m.add_function(wrap_pyfunction!(make_moons_py, m)?)?;
356 m.add_function(wrap_pyfunction!(make_circles_py, m)?)?;
357 m.add_function(wrap_pyfunction!(make_spirals_py, m)?)?;
358
359 m.add_function(wrap_pyfunction!(make_swiss_roll_py, m)?)?;
361 m.add_function(wrap_pyfunction!(make_s_curve_py, m)?)?;
362
363 m.add_function(wrap_pyfunction!(train_test_split_py, m)?)?;
365 m.add_function(wrap_pyfunction!(k_fold_split_py, m)?)?;
366 m.add_function(wrap_pyfunction!(min_max_scale_py, m)?)?;
367 m.add_function(wrap_pyfunction!(normalize_py, m)?)?;
368
369 Ok(())
370}