//! Python bindings for scirs2-datasets
//!
//! This module provides Python bindings for dataset loading and generation,
//! including toy datasets, synthetic data generators, and utilities.

use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict};

// NumPy types for the Python array interface
use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};

// ndarray types from scirs2-core (used indirectly via Dataset)
#[allow(unused_imports)]
use scirs2_core::ndarray::{Array1, Array2};

// Direct imports from scirs2-datasets
use scirs2_datasets::{
    // Manifold datasets
    generators::manifold::{make_s_curve, make_swiss_roll},
    // Generators
    generators::{
        make_blobs, make_circles, make_classification, make_moons, make_regression, make_spirals,
    },
    // Toy datasets
    toy::{load_boston, load_breast_cancer, load_diabetes, load_digits, load_iris},
    // Utilities
    utils::{k_fold_split, min_max_scale, normalize, train_test_split},
    // Dataset structure
    Dataset,
};

// ========================================
// HELPER FUNCTION
// ========================================

37/// Convert Dataset to Python dict
38fn dataset_to_pydict(py: Python, dataset: Dataset) -> PyResult<Py<PyAny>> {
39    let dict = PyDict::new(py);
40    let n_samples = dataset.data.nrows();
41    let n_features = dataset.data.ncols();
42    dict.set_item("data", dataset.data.into_pyarray(py).unbind())?;
43
44    if let Some(target) = dataset.target {
45        dict.set_item("target", target.into_pyarray(py).unbind())?;
46    }
47
48    if let Some(featurenames) = dataset.featurenames {
49        dict.set_item("feature_names", featurenames)?;
50    }
51
52    if let Some(targetnames) = dataset.targetnames {
53        dict.set_item("target_names", targetnames)?;
54    }
55
56    if let Some(description) = dataset.description {
57        dict.set_item("description", description)?;
58    }
59
60    dict.set_item("n_samples", n_samples)?;
61    dict.set_item("n_features", n_features)?;
62
63    Ok(dict.into())
64}
65
// ========================================
// TOY DATASETS
// ========================================

70/// Load Iris dataset
71#[pyfunction]
72fn load_iris_py(py: Python) -> PyResult<Py<PyAny>> {
73    let dataset =
74        load_iris().map_err(|e| PyRuntimeError::new_err(format!("Failed to load iris: {}", e)))?;
75    dataset_to_pydict(py, dataset)
76}
77
78/// Load Boston housing dataset
79#[pyfunction]
80fn load_boston_py(py: Python) -> PyResult<Py<PyAny>> {
81    let dataset = load_boston()
82        .map_err(|e| PyRuntimeError::new_err(format!("Failed to load boston: {}", e)))?;
83    dataset_to_pydict(py, dataset)
84}
85
86/// Load Diabetes dataset
87#[pyfunction]
88fn load_diabetes_py(py: Python) -> PyResult<Py<PyAny>> {
89    let dataset = load_diabetes()
90        .map_err(|e| PyRuntimeError::new_err(format!("Failed to load diabetes: {}", e)))?;
91    dataset_to_pydict(py, dataset)
92}
93
94/// Load Breast Cancer dataset
95#[pyfunction]
96fn load_breast_cancer_py(py: Python) -> PyResult<Py<PyAny>> {
97    let dataset = load_breast_cancer()
98        .map_err(|e| PyRuntimeError::new_err(format!("Failed to load breast cancer: {}", e)))?;
99    dataset_to_pydict(py, dataset)
100}
101
102/// Load Digits dataset
103#[pyfunction]
104fn load_digits_py(py: Python) -> PyResult<Py<PyAny>> {
105    let dataset = load_digits()
106        .map_err(|e| PyRuntimeError::new_err(format!("Failed to load digits: {}", e)))?;
107    dataset_to_pydict(py, dataset)
108}
109
// ========================================
// DATA GENERATORS
// ========================================

114/// Generate synthetic classification dataset
115#[pyfunction]
116#[pyo3(signature = (n_samples=100, n_features=20, n_informative=2, n_redundant=2, n_clusters_per_class=2, random_seed=None))]
117fn make_classification_py(
118    py: Python,
119    n_samples: usize,
120    n_features: usize,
121    n_informative: usize,
122    n_redundant: usize,
123    n_clusters_per_class: usize,
124    random_seed: Option<u64>,
125) -> PyResult<Py<PyAny>> {
126    let dataset = make_classification(
127        n_samples,
128        n_features,
129        n_informative,
130        n_redundant,
131        n_clusters_per_class,
132        random_seed,
133    )
134    .map_err(|e| PyRuntimeError::new_err(format!("Failed to make classification: {}", e)))?;
135    dataset_to_pydict(py, dataset)
136}
137
138/// Generate synthetic regression dataset
139#[pyfunction]
140#[pyo3(signature = (n_samples=100, n_features=10, n_informative=5, noise=0.1, random_seed=None))]
141fn make_regression_py(
142    py: Python,
143    n_samples: usize,
144    n_features: usize,
145    n_informative: usize,
146    noise: f64,
147    random_seed: Option<u64>,
148) -> PyResult<Py<PyAny>> {
149    let dataset = make_regression(n_samples, n_features, n_informative, noise, random_seed)
150        .map_err(|e| PyRuntimeError::new_err(format!("Failed to make regression: {}", e)))?;
151    dataset_to_pydict(py, dataset)
152}
153
154/// Generate blob clusters
155#[pyfunction]
156#[pyo3(signature = (n_samples=100, n_features=2, n_clusters=3, std_dev=1.0, random_seed=None))]
157fn make_blobs_py(
158    py: Python,
159    n_samples: usize,
160    n_features: usize,
161    n_clusters: usize,
162    std_dev: f64,
163    random_seed: Option<u64>,
164) -> PyResult<Py<PyAny>> {
165    let dataset = make_blobs(n_samples, n_features, n_clusters, std_dev, random_seed)
166        .map_err(|e| PyRuntimeError::new_err(format!("Failed to make blobs: {}", e)))?;
167    dataset_to_pydict(py, dataset)
168}
169
170/// Generate two interleaving half circles (moons)
171#[pyfunction]
172#[pyo3(signature = (n_samples=100, noise=0.1, random_seed=None))]
173fn make_moons_py(
174    py: Python,
175    n_samples: usize,
176    noise: f64,
177    random_seed: Option<u64>,
178) -> PyResult<Py<PyAny>> {
179    let dataset = make_moons(n_samples, noise, random_seed)
180        .map_err(|e| PyRuntimeError::new_err(format!("Failed to make moons: {}", e)))?;
181    dataset_to_pydict(py, dataset)
182}
183
184/// Generate two concentric circles
185#[pyfunction]
186#[pyo3(signature = (n_samples=100, factor=0.8, noise=0.1, random_seed=None))]
187fn make_circles_py(
188    py: Python,
189    n_samples: usize,
190    factor: f64,
191    noise: f64,
192    random_seed: Option<u64>,
193) -> PyResult<Py<PyAny>> {
194    let dataset = make_circles(n_samples, factor, noise, random_seed)
195        .map_err(|e| PyRuntimeError::new_err(format!("Failed to make circles: {}", e)))?;
196    dataset_to_pydict(py, dataset)
197}
198
199/// Generate spiral clusters
200#[pyfunction]
201#[pyo3(signature = (n_samples=100, n_spirals=2, noise=0.1, random_seed=None))]
202fn make_spirals_py(
203    py: Python,
204    n_samples: usize,
205    n_spirals: usize,
206    noise: f64,
207    random_seed: Option<u64>,
208) -> PyResult<Py<PyAny>> {
209    let dataset = make_spirals(n_samples, n_spirals, noise, random_seed)
210        .map_err(|e| PyRuntimeError::new_err(format!("Failed to make spirals: {}", e)))?;
211    dataset_to_pydict(py, dataset)
212}
213
// ========================================
// MANIFOLD DATASETS
// ========================================

218/// Generate Swiss roll dataset
219#[pyfunction]
220#[pyo3(signature = (n_samples=1000, noise=0.0, random_seed=None))]
221fn make_swiss_roll_py(
222    py: Python,
223    n_samples: usize,
224    noise: f64,
225    random_seed: Option<u64>,
226) -> PyResult<Py<PyAny>> {
227    let dataset = make_swiss_roll(n_samples, noise, random_seed)
228        .map_err(|e| PyRuntimeError::new_err(format!("Failed to make swiss roll: {}", e)))?;
229    dataset_to_pydict(py, dataset)
230}
231
232/// Generate S-curve dataset
233#[pyfunction]
234#[pyo3(signature = (n_samples=1000, noise=0.0, random_seed=None))]
235fn make_s_curve_py(
236    py: Python,
237    n_samples: usize,
238    noise: f64,
239    random_seed: Option<u64>,
240) -> PyResult<Py<PyAny>> {
241    let dataset = make_s_curve(n_samples, noise, random_seed)
242        .map_err(|e| PyRuntimeError::new_err(format!("Failed to make s-curve: {}", e)))?;
243    dataset_to_pydict(py, dataset)
244}
245
// ========================================
// DATA UTILITIES
// ========================================

250/// Split arrays into train and test subsets
251#[pyfunction]
252#[pyo3(signature = (x, y, test_size=0.25, random_seed=None))]
253fn train_test_split_py(
254    py: Python,
255    x: &Bound<'_, PyArray2<f64>>,
256    y: &Bound<'_, PyArray1<f64>>,
257    test_size: f64,
258    random_seed: Option<u64>,
259) -> PyResult<Py<PyAny>> {
260    let x_binding = x.readonly();
261    let y_binding = y.readonly();
262    let x_data = x_binding.as_array().to_owned();
263    let y_data = y_binding.as_array().to_owned();
264
265    let dataset = Dataset {
266        data: x_data,
267        target: Some(y_data),
268        featurenames: None,
269        targetnames: None,
270        feature_descriptions: None,
271        description: None,
272        metadata: std::collections::HashMap::new(),
273    };
274
275    let (train, test) = train_test_split(&dataset, test_size, random_seed)
276        .map_err(|e| PyRuntimeError::new_err(format!("Train-test split failed: {}", e)))?;
277
278    let dict = PyDict::new(py);
279    dict.set_item("x_train", train.data.into_pyarray(py).unbind())?;
280    dict.set_item("x_test", test.data.into_pyarray(py).unbind())?;
281
282    if let Some(y_train) = train.target {
283        dict.set_item("y_train", y_train.into_pyarray(py).unbind())?;
284    }
285    if let Some(y_test) = test.target {
286        dict.set_item("y_test", y_test.into_pyarray(py).unbind())?;
287    }
288
289    Ok(dict.into())
290}
291
292/// K-fold cross-validation split indices
293#[pyfunction]
294#[pyo3(signature = (n_samples, n_folds=5, shuffle=true, random_seed=None))]
295fn k_fold_split_py(
296    py: Python,
297    n_samples: usize,
298    n_folds: usize,
299    shuffle: bool,
300    random_seed: Option<u64>,
301) -> PyResult<Py<PyAny>> {
302    let folds = k_fold_split(n_samples, n_folds, shuffle, random_seed)
303        .map_err(|e| PyRuntimeError::new_err(format!("K-fold split failed: {}", e)))?;
304
305    let dict = PyDict::new(py);
306    dict.set_item("n_folds", n_folds)?;
307
308    // folds is already Vec<(Vec<usize>, Vec<usize>)>
309    dict.set_item("folds", folds)?;
310
311    Ok(dict.into())
312}
313
314/// Min-max scale array to [0, 1] range
315#[pyfunction]
316#[pyo3(signature = (data, min_val=0.0, max_val=1.0))]
317fn min_max_scale_py(
318    py: Python,
319    data: &Bound<'_, PyArray2<f64>>,
320    min_val: f64,
321    max_val: f64,
322) -> PyResult<Py<PyArray2<f64>>> {
323    let binding = data.readonly();
324    let mut arr = binding.as_array().to_owned();
325
326    min_max_scale(&mut arr, (min_val, max_val));
327
328    Ok(arr.into_pyarray(py).unbind())
329}
330
331/// Normalize array (L2 norm per row)
332#[pyfunction]
333fn normalize_py(py: Python, data: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
334    let binding = data.readonly();
335    let mut arr = binding.as_array().to_owned();
336
337    normalize(&mut arr);
338
339    Ok(arr.into_pyarray(py).unbind())
340}
341
342/// Python module registration
343pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
344    // Toy datasets
345    m.add_function(wrap_pyfunction!(load_iris_py, m)?)?;
346    m.add_function(wrap_pyfunction!(load_boston_py, m)?)?;
347    m.add_function(wrap_pyfunction!(load_diabetes_py, m)?)?;
348    m.add_function(wrap_pyfunction!(load_breast_cancer_py, m)?)?;
349    m.add_function(wrap_pyfunction!(load_digits_py, m)?)?;
350
351    // Data generators
352    m.add_function(wrap_pyfunction!(make_classification_py, m)?)?;
353    m.add_function(wrap_pyfunction!(make_regression_py, m)?)?;
354    m.add_function(wrap_pyfunction!(make_blobs_py, m)?)?;
355    m.add_function(wrap_pyfunction!(make_moons_py, m)?)?;
356    m.add_function(wrap_pyfunction!(make_circles_py, m)?)?;
357    m.add_function(wrap_pyfunction!(make_spirals_py, m)?)?;
358
359    // Manifold datasets
360    m.add_function(wrap_pyfunction!(make_swiss_roll_py, m)?)?;
361    m.add_function(wrap_pyfunction!(make_s_curve_py, m)?)?;
362
363    // Utilities
364    m.add_function(wrap_pyfunction!(train_test_split_py, m)?)?;
365    m.add_function(wrap_pyfunction!(k_fold_split_py, m)?)?;
366    m.add_function(wrap_pyfunction!(min_max_scale_py, m)?)?;
367    m.add_function(wrap_pyfunction!(normalize_py, m)?)?;
368
369    Ok(())
370}