scirs2_datasets/
sample.rs1use crate::error::{DatasetsError, Result};
7use crate::utils::Dataset;
8
9#[cfg(feature = "download")]
10use crate::cache::download_data;
11#[cfg(feature = "download")]
12use crate::loaders;
13
/// Base URL of the raw-file view of the `cool-japan/scirs-datasets` GitHub repository.
///
/// Note that the value ends with a trailing `/`; code that appends a path segment
/// should not add a second separator.
// Within this file the constant is only referenced from the `download`-gated
// functions, so builds without that feature would otherwise warn that it is unused.
#[allow(dead_code)]
const DATASET_BASE_URL: &str = "https://raw.githubusercontent.com/cool-japan/scirs-datasets/main/";
17
/// Loads the California Housing regression dataset, downloading it if needed.
///
/// The CSV is fetched with `download_data` from `DATASET_BASE_URL` and written to a
/// temporary file. It is then parsed with `loaders::load_csv` using a header row and
/// the target in column 8. The eight feature names and a textual description are
/// attached to the returned [`Dataset`].
///
/// # Arguments
///
/// * `force_download` - forwarded unchanged to `download_data` (presumably bypasses
///   any local cache when `true`; see `crate::cache`).
///
/// # Errors
///
/// Returns an error if the download fails, if the temporary file cannot be written
/// (`DatasetsError::IoError`), or if the CSV cannot be parsed. Once the temporary file
/// has been created, it is removed on a best-effort basis on every path, including
/// failures.
#[cfg(feature = "download")]
#[allow(dead_code)]
pub fn load_california_housing(force_download: bool) -> Result<Dataset> {
    // `DATASET_BASE_URL` already ends with '/'; trim it so the URL contains no `//`.
    let url = format!(
        "{}/california_housing.csv",
        DATASET_BASE_URL.trim_end_matches('/')
    );

    let data = download_data(&url, force_download)?;

    let temppath = std::env::temp_dir().join("scirs2_california_housing.csv");

    // `fs::write` creates/truncates the file and closes the handle before returning,
    // so the loader below sees the complete contents.
    if let Err(err) = std::fs::write(&temppath, &data) {
        std::fs::remove_file(&temppath).ok();
        return Err(DatasetsError::IoError(err));
    }

    let config = loaders::CsvConfig::new()
        .with_header(true)
        .with_target_column(Some(8));
    let loaded = loaders::load_csv(&temppath, config);

    // Best-effort cleanup *before* `?`, so a parse failure does not leave the file behind.
    std::fs::remove_file(&temppath).ok();
    let dataset = loaded?;

    let featurenames = vec![
        "MedInc".to_string(),
        "HouseAge".to_string(),
        "AveRooms".to_string(),
        "AveBedrms".to_string(),
        "Population".to_string(),
        "AveOccup".to_string(),
        "Latitude".to_string(),
        "Longitude".to_string(),
    ];

    let description = "California Housing dataset
    
The data was derived from the 1990 U.S. census, using one row per census block group.
A block group is the smallest geographical unit for which the U.S. Census Bureau 
publishes sample data.

Features:
- MedInc: median income in block group
- HouseAge: median house age in block group
- AveRooms: average number of rooms per household
- AveBedrms: average number of bedrooms per household
- Population: block group population
- AveOccup: average number of household members
- Latitude: block group latitude
- Longitude: block group longitude

Target: Median house value for California districts, expressed in hundreds of thousands of dollars.

This dataset is useful for regression tasks."
        .to_string();

    Ok(dataset
        .with_featurenames(featurenames)
        .with_description(description))
}
84
85#[cfg(not(feature = "download"))]
87#[allow(dead_code)]
100pub fn load_california_housing(_forcedownload: bool) -> Result<Dataset> {
101    Err(DatasetsError::Other(
102        "Download feature is not enabled. Recompile with --features _download".to_string(),
103    ))
104}
105
/// Loads the Wine Recognition classification dataset, downloading it if needed.
///
/// The CSV is fetched with `download_data` from `DATASET_BASE_URL` and written to a
/// temporary file. It is then parsed with `loaders::load_csv_legacy` using a header row
/// and the target in column 0. The 13 feature names, the three class names, and a
/// textual description are attached to the returned [`Dataset`].
///
/// # Arguments
///
/// * `force_download` - forwarded unchanged to `download_data` (presumably bypasses
///   any local cache when `true`; see `crate::cache`).
///
/// # Errors
///
/// Returns an error if the download fails, if the temporary file cannot be written
/// (`DatasetsError::IoError`), or if the CSV cannot be parsed. Once the temporary file
/// has been created, it is removed on a best-effort basis on every path, including
/// failures.
#[cfg(feature = "download")]
#[allow(dead_code)]
pub fn load_wine(force_download: bool) -> Result<Dataset> {
    // `DATASET_BASE_URL` already ends with '/'; trim it so the URL contains no `//`.
    let url = format!("{}/wine.csv", DATASET_BASE_URL.trim_end_matches('/'));

    let data = download_data(&url, force_download)?;

    let temppath = std::env::temp_dir().join("scirs2_wine.csv");

    // `fs::write` creates/truncates the file and closes the handle before returning,
    // so the loader below sees the complete contents.
    if let Err(err) = std::fs::write(&temppath, &data) {
        std::fs::remove_file(&temppath).ok();
        return Err(DatasetsError::IoError(err));
    }

    let loaded = loaders::load_csv_legacy(&temppath, true, Some(0));

    // Best-effort cleanup *before* `?`, so a parse failure does not leave the file behind.
    std::fs::remove_file(&temppath).ok();
    let dataset = loaded?;

    let featurenames = vec![
        "alcohol".to_string(),
        "malic_acid".to_string(),
        "ash".to_string(),
        "alcalinity_of_ash".to_string(),
        "magnesium".to_string(),
        "total_phenols".to_string(),
        "flavanoids".to_string(),
        "nonflavanoid_phenols".to_string(),
        "proanthocyanins".to_string(),
        "color_intensity".to_string(),
        "hue".to_string(),
        "od280_od315_of_diluted_wines".to_string(),
        "proline".to_string(),
    ];

    let targetnames = vec![
        "class_0".to_string(),
        "class_1".to_string(),
        "class_2".to_string(),
    ];

    let description = "Wine Recognition dataset
    
The data is the results of a chemical analysis of wines grown in the same region in Italy
but derived from three different cultivars. The analysis determined the quantities of
13 constituents found in each of the three types of wines.

Features: Various chemical properties of the wine

Target: Class of wine (0, 1, or 2)

This dataset is useful for classification tasks."
        .to_string();

    Ok(dataset
        .with_featurenames(featurenames)
        .with_targetnames(targetnames)
        .with_description(description))
}
173
174#[cfg(not(feature = "download"))]
176#[allow(dead_code)]
189pub fn load_wine(_forcedownload: bool) -> Result<Dataset> {
190    Err(DatasetsError::Other(
191        "Download feature is not enabled. Recompile with --features _download".to_string(),
192    ))
193}
194
/// Fetches the list of dataset names published in the remote index file.
///
/// Downloads `datasets_index.txt` from the dataset repository. The download is always
/// forced (`force_download = true`). One entry is returned per non-blank line, with
/// surrounding whitespace trimmed.
///
/// # Errors
///
/// Returns an error if the download fails, or `DatasetsError::InvalidFormat` if the
/// index is not valid UTF-8.
#[cfg(feature = "download")]
#[allow(dead_code)]
pub fn get_available_datasets() -> Result<Vec<String>> {
    // `DATASET_BASE_URL` already ends with '/'; trim it so the URL contains no `//`.
    let url = format!(
        "{}/datasets_index.txt",
        DATASET_BASE_URL.trim_end_matches('/')
    );

    // Always re-download (presumably so a stale cached index is never used).
    let data = download_data(&url, true)?;

    let content = String::from_utf8(data).map_err(|e| {
        DatasetsError::InvalidFormat(format!("Failed to parse datasets index: {e}"))
    })?;

    let datasets = content
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(str::to_string)
        .collect();

    Ok(datasets)
}
217
218#[cfg(not(feature = "download"))]
220#[allow(dead_code)]
229pub fn get_available_datasets() -> Result<Vec<String>> {
230    Err(DatasetsError::Other(
231        "Download feature is not enabled. Recompile with --features download".to_string(),
232    ))
233}