1use crate::cache::RegistryEntry;
4use crate::error::{DatasetsError, Result};
5use std::collections::HashMap;
6
/// Descriptive metadata for a dataset known to the registry.
///
/// `Default` yields empty strings, zero counts and `None` for every optional
/// field. The type is comparable with `==`, which makes it usable directly in
/// `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DatasetMetadata {
    /// Human-readable dataset name (e.g. `"Iris"`).
    pub name: String,
    /// Short free-form description of the dataset.
    pub description: String,
    /// Number of samples in the dataset.
    pub n_samples: usize,
    /// Number of features per sample.
    pub n_features: usize,
    /// Task category; the values used in this file are `"classification"`
    /// and `"regression"`.
    pub task_type: String,
    /// Names of the target classes, when available.
    pub targetnames: Option<Vec<String>>,
    /// Names of the features, when available.
    pub featurenames: Option<Vec<String>>,
    /// Source URL of the dataset, when known.
    pub url: Option<String>,
    /// Checksum of the dataset, when known.
    pub checksum: Option<String>,
}
29
30pub struct DatasetRegistry {
32 entries: HashMap<String, RegistryEntry>,
34}
35
36impl Default for DatasetRegistry {
37 fn default() -> Self {
38 let mut registry = Self::new();
39 registry.populate_default_datasets();
40 registry
41 }
42}
43
44impl DatasetRegistry {
45 pub fn new() -> Self {
47 Self {
48 entries: HashMap::new(),
49 }
50 }
51
52 pub fn register(&mut self, name: String, entry: RegistryEntry) {
54 self.entries.insert(name, entry);
55 }
56
57 pub fn get(&self, name: &str) -> Option<&RegistryEntry> {
59 self.entries.get(name)
60 }
61
62 pub fn list_datasets(&self) -> Vec<String> {
64 self.entries.keys().cloned().collect()
65 }
66
67 pub fn contains(&self, name: &str) -> bool {
69 self.entries.contains_key(name)
70 }
71
72 pub fn get_metadata(&self, name: &str) -> Result<DatasetMetadata> {
74 match name {
75 "iris" => Ok(DatasetMetadata {
76 name: "Iris".to_string(),
77 description: "Classic iris flower dataset for classification".to_string(),
78 n_samples: 150,
79 n_features: 4,
80 task_type: "classification".to_string(),
81 targetnames: Some(vec![
82 "setosa".to_string(),
83 "versicolor".to_string(),
84 "virginica".to_string(),
85 ]),
86 featurenames: Some(vec![
87 "sepal_length".to_string(),
88 "sepal_width".to_string(),
89 "petal_length".to_string(),
90 "petal_width".to_string(),
91 ]),
92 url: None,
93 checksum: None,
94 }),
95 "boston" => Ok(DatasetMetadata {
96 name: "Boston Housing".to_string(),
97 description: "Boston housing prices dataset for regression".to_string(),
98 n_samples: 506,
99 n_features: 13,
100 task_type: "regression".to_string(),
101 targetnames: None,
102 featurenames: None,
103 url: None,
104 checksum: None,
105 }),
106 "digits" => Ok(DatasetMetadata {
107 name: "Digits".to_string(),
108 description: "Hand-written digits dataset for image classification".to_string(),
109 n_samples: 1797,
110 n_features: 64,
111 task_type: "classification".to_string(),
112 targetnames: Some(vec![
113 "0".to_string(),
114 "1".to_string(),
115 "2".to_string(),
116 "3".to_string(),
117 "4".to_string(),
118 "5".to_string(),
119 "6".to_string(),
120 "7".to_string(),
121 "8".to_string(),
122 "9".to_string(),
123 ]),
124 featurenames: None,
125 url: None,
126 checksum: None,
127 }),
128 "wine" => Ok(DatasetMetadata {
129 name: "Wine".to_string(),
130 description: "Wine recognition dataset for classification".to_string(),
131 n_samples: 178,
132 n_features: 13,
133 task_type: "classification".to_string(),
134 targetnames: Some(vec![
135 "class_0".to_string(),
136 "class_1".to_string(),
137 "class_2".to_string(),
138 ]),
139 featurenames: None,
140 url: None,
141 checksum: None,
142 }),
143 "breast_cancer" => Ok(DatasetMetadata {
144 name: "Breast Cancer".to_string(),
145 description: "Breast cancer wisconsin dataset for classification".to_string(),
146 n_samples: 569,
147 n_features: 30,
148 task_type: "classification".to_string(),
149 targetnames: Some(vec!["malignant".to_string(), "benign".to_string()]),
150 featurenames: None,
151 url: None,
152 checksum: None,
153 }),
154 "diabetes" => Ok(DatasetMetadata {
155 name: "Diabetes".to_string(),
156 description: "Diabetes dataset for regression".to_string(),
157 n_samples: 442,
158 n_features: 10,
159 task_type: "regression".to_string(),
160 targetnames: None,
161 featurenames: None,
162 url: None,
163 checksum: None,
164 }),
165 _ => Err(DatasetsError::Other(format!("Unknown dataset: {name}"))),
166 }
167 }
168
169 fn populate_default_datasets(&mut self) {
174 self.register(
176 "example".to_string(),
177 RegistryEntry {
178 url: "file://data/example.csv",
179 sha256: "c51c3ff2e8a5db28b1baed809a2ba29f29643e5a26ad476448eb3889996173d6",
180 },
181 );
182
183 self.register(
184 "sample_data".to_string(),
185 RegistryEntry {
186 url: "file://examples/sample_data.csv",
187 sha256: "59cceb2c80692ee2c1c3b607335d1feb983ceed24214d1ffc2eace9f3ce5ab47",
188 },
189 );
190
191 self.register_toy_dataset("iris", "Classic iris flower dataset for classification");
193 self.register_toy_dataset("boston", "Boston housing prices dataset for regression");
194 self.register_toy_dataset(
195 "digits",
196 "Hand-written digits dataset for image classification",
197 );
198 self.register_toy_dataset("wine", "Wine recognition dataset for classification");
199 self.register_toy_dataset(
200 "breast_cancer",
201 "Breast cancer wisconsin dataset for classification",
202 );
203 self.register_toy_dataset("diabetes", "Diabetes dataset for regression");
204
205 }
240
241 fn register_toy_dataset(&mut self, name: &str, _description: &str) {
243 let url = match name {
244 "iris" => "builtin://iris",
245 "boston" => "builtin://boston",
246 "digits" => "builtin://digits",
247 "wine" => "builtin://wine",
248 "breast_cancer" => "builtin://breast_cancer",
249 "diabetes" => "builtin://diabetes",
250 _ => "builtin://unknown",
251 };
252
253 self.register(
254 name.to_string(),
255 RegistryEntry {
256 url,
257 sha256: "builtin", },
259 );
260 }
261}
262
263#[allow(dead_code)]
265pub fn get_registry() -> DatasetRegistry {
266 DatasetRegistry::default()
267}
268
/// Loads a dataset by its registry name.
///
/// The name is looked up in a freshly built default registry and dispatched
/// on the URL scheme of the matching entry:
///
/// * `builtin://` — calls the matching built-in loader.
/// * `file://` — loads the local CSV file through `load_local_dataset`,
///   passing along the entry's SHA-256 hash.
/// * URLs starting with `http` — calls the matching remote loader.
///
/// # Errors
///
/// Returns `DatasetsError::Other` when the name is not registered, when the
/// URL scheme is not supported, or when no loader exists for the name within
/// its scheme. Errors from the selected loader are returned unchanged.
#[cfg(feature = "download")]
#[allow(dead_code)]
pub fn load_dataset_byname(name: &str, forcedownload: bool) -> Result<crate::utils::Dataset> {
    let registry = get_registry();

    let entry = match registry.get(name) {
        Some(found) => found,
        None => {
            return Err(DatasetsError::Other(format!(
                "Unknown dataset: '{}'. Available datasets: {:?}",
                name,
                registry.list_datasets()
            )));
        }
    };

    let location: &str = entry.url;

    if location.starts_with("builtin://") {
        return match name {
            "iris" => crate::toy::load_iris(),
            "boston" => crate::toy::load_boston(),
            "digits" => crate::toy::load_digits(),
            "wine" => crate::sample::load_wine(false),
            "breast_cancer" => crate::toy::load_breast_cancer(),
            "diabetes" => crate::toy::load_diabetes(),
            _ => Err(DatasetsError::Other(format!(
                "Built-in dataset '{}' not implemented",
                name
            ))),
        };
    }

    // Everything after the scheme is the path handed to the local loader.
    if let Some(relative) = location.strip_prefix("file://") {
        return load_local_dataset(name, relative, entry.sha256);
    }

    if location.starts_with("http") {
        // NOTE(review): `forcedownload` is only forwarded to
        // `load_california_housing`; the other loaders are called with fixed
        // arguments — confirm this is intended.
        return match name {
            "california_housing" => crate::sample::load_california_housing(forcedownload),
            "electrocardiogram" => crate::time_series::electrocardiogram(),
            "stock_market" => crate::time_series::stock_market(false),
            "weather" => crate::time_series::weather(None),
            _ => Err(DatasetsError::Other(format!(
                "Remote dataset '{}' not yet implemented for loading",
                name
            ))),
        };
    }

    Err(DatasetsError::Other(format!(
        "Unsupported URL scheme for dataset '{}': {}",
        name, location
    )))
}
320
/// Loads a CSV dataset from a path relative to the crate's manifest
/// directory.
///
/// * `name` — dataset name, used in error messages and in the description of
///   the returned dataset.
/// * `relativepath` — path joined onto `CARGO_MANIFEST_DIR` (resolved at
///   compile time).
/// * `expected_sha256` — expected hash of the file; the sentinel `"builtin"`
///   disables verification.
///
/// The file is read as CSV with a header row and the resulting dataset is
/// described as `"Local dataset: <name>"`.
///
/// # Errors
///
/// Returns `DatasetsError::Other` when the file does not exist, when its hash
/// cannot be computed, or when the computed hash differs from
/// `expected_sha256`. Errors from `load_csv` are propagated.
#[cfg(feature = "download")]
#[allow(dead_code)]
fn load_local_dataset(
    name: &str,
    relativepath: &str,
    expected_sha256: &str,
) -> Result<crate::utils::Dataset> {
    use crate::loaders::{load_csv, CsvConfig};
    use std::path::Path;

    let workspace_root = env!("CARGO_MANIFEST_DIR");
    let filepath = Path::new(workspace_root).join(relativepath);

    if !filepath.exists() {
        return Err(DatasetsError::Other(format!(
            "Local dataset file not found: {}",
            filepath.display()
        )));
    }

    if expected_sha256 != "builtin" {
        // A hashing failure must not be treated as a successful
        // verification, so it is reported instead of being skipped.
        let actual_hash = crate::cache::sha256_hash_file(&filepath).map_err(|err| {
            DatasetsError::Other(format!(
                "Failed to compute SHA-256 for dataset '{}' ({}): {}",
                name,
                filepath.display(),
                err
            ))
        })?;

        if actual_hash != expected_sha256 {
            return Err(DatasetsError::Other(format!(
                "Hash verification failed for dataset '{}'. Expected: {}, Got: {}",
                name, expected_sha256, actual_hash
            )));
        }
    }

    let config = CsvConfig::default().with_header(true);
    let dataset = load_csv(&filepath, config)?;

    Ok(dataset.with_description(format!("Local dataset: {}", name)))
}
364
365#[cfg(not(feature = "download"))]
366#[allow(dead_code)]
368pub fn load_dataset_byname(_name: &str, _forcedownload: bool) -> Result<crate::utils::Dataset> {
369 Err(DatasetsError::Other(
370 "Download feature is not enabled. Recompile with --features _download".to_string(),
371 ))
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// A registry created with `new` starts out empty.
    #[test]
    fn test_registry_creation() {
        let registry = DatasetRegistry::new();
        assert!(registry.entries.is_empty());
    }

    /// The default registry contains the local files and the toy datasets.
    #[test]
    fn test_registry_default() {
        let registry = DatasetRegistry::default();
        assert!(!registry.entries.is_empty());

        assert!(registry.contains("example"));
        assert!(registry.contains("sample_data"));

        assert!(registry.contains("iris"));
        assert!(registry.contains("boston"));
        assert!(registry.contains("wine"));
        assert!(registry.contains("digits"));
        assert!(registry.contains("breast_cancer"));
        assert!(registry.contains("diabetes"));
    }

    /// Registering, querying and listing a single custom entry.
    #[test]
    fn test_registry_operations() {
        let mut registry = DatasetRegistry::new();

        let entry = RegistryEntry {
            url: "https://example.com/test.csv",
            sha256: "abcd1234",
        };

        registry.register("test_dataset".to_string(), entry);

        assert!(registry.contains("test_dataset"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry
            .get("test_dataset")
            .expect("entry registered above must be retrievable");
        assert_eq!(retrieved.url, "https://example.com/test.csv");
        assert_eq!(retrieved.sha256, "abcd1234");

        let datasets = registry.list_datasets();
        assert_eq!(datasets.len(), 1);
        assert!(datasets.contains(&"test_dataset".to_string()));
    }

    /// `get_registry` returns a populated registry.
    #[test]
    fn test_get_registry() {
        let registry = get_registry();
        assert!(!registry.list_datasets().is_empty());
    }

    /// Default entries carry the expected URL scheme and hash.
    #[test]
    fn test_registry_url_schemes() {
        let registry = DatasetRegistry::default();

        // `expect` instead of `if let`: a missing entry must fail the test
        // rather than silently skip the assertions.
        let iris_entry = registry
            .get("iris")
            .expect("iris must be registered by default");
        assert_eq!(iris_entry.url, "builtin://iris");
        assert_eq!(iris_entry.sha256, "builtin");

        let example_entry = registry
            .get("example")
            .expect("example must be registered by default");
        assert_eq!(example_entry.url, "file://data/example.csv");
        assert_eq!(
            example_entry.sha256,
            "c51c3ff2e8a5db28b1baed809a2ba29f29643e5a26ad476448eb3889996173d6"
        );
    }

    /// The default registry holds exactly the eight expected datasets.
    #[test]
    fn test_dataset_count() {
        let registry = DatasetRegistry::default();
        let datasets = registry.list_datasets();

        assert_eq!(datasets.len(), 8);

        let expected_datasets = [
            "example",
            "sample_data",
            "iris",
            "boston",
            "digits",
            "wine",
            "breast_cancer",
            "diabetes",
        ];

        for &expected in &expected_datasets {
            assert!(
                datasets.contains(&expected.to_string()),
                "Dataset '{expected}' not found in registry"
            );
        }
    }

    /// Metadata is available for built-in datasets and rejected otherwise.
    #[test]
    fn test_get_metadata() {
        let registry = DatasetRegistry::default();

        let iris = match registry.get_metadata("iris") {
            Ok(metadata) => metadata,
            Err(_) => panic!("iris metadata must be available"),
        };
        assert_eq!(iris.name, "Iris");
        assert_eq!(iris.n_samples, 150);
        assert_eq!(iris.n_features, 4);
        assert_eq!(iris.task_type, "classification");
        assert_eq!(iris.targetnames.map(|names| names.len()), Some(3));
        assert_eq!(iris.featurenames.map(|names| names.len()), Some(4));

        assert!(registry.get_metadata("nonexistent").is_err());
    }
}