scirs2_datasets/
registry.rs

1//! Dataset registry system for managing dataset metadata and locations
2
3use crate::cache::RegistryEntry;
4use crate::error::{DatasetsError, Result};
5use std::collections::HashMap;
6
/// Descriptive metadata for a single dataset.
///
/// Every field has a default value, so partial records can be built with
/// struct-update syntax (`DatasetMetadata { name, ..Default::default() }`).
/// Equality compares all fields, which lets callers and tests compare whole
/// records directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DatasetMetadata {
    /// Human-readable name of the dataset
    pub name: String,
    /// Short free-text description of the dataset
    pub description: String,
    /// Number of samples (rows) in the dataset
    pub n_samples: usize,
    /// Number of features (columns) in the dataset
    pub n_features: usize,
    /// Kind of learning task, stored as text (the values used in this file
    /// are `"classification"` and `"regression"`)
    pub task_type: String,
    /// Optional target (class) names for classification problems
    pub targetnames: Option<Vec<String>>,
    /// Optional feature names, one per feature
    pub featurenames: Option<Vec<String>>,
    /// Optional download URL
    pub url: Option<String>,
    /// Optional checksum for verification
    pub checksum: Option<String>,
}
29
30/// Global dataset registry containing metadata for downloadable datasets
31pub struct DatasetRegistry {
32    /// Map from dataset name to registry entry
33    entries: HashMap<String, RegistryEntry>,
34}
35
36impl Default for DatasetRegistry {
37    fn default() -> Self {
38        let mut registry = Self::new();
39        registry.populate_default_datasets();
40        registry
41    }
42}
43
44impl DatasetRegistry {
45    /// Create a new empty registry
46    pub fn new() -> Self {
47        Self {
48            entries: HashMap::new(),
49        }
50    }
51
52    /// Register a new dataset with the given name and metadata
53    pub fn register(&mut self, name: String, entry: RegistryEntry) {
54        self.entries.insert(name, entry);
55    }
56
57    /// Get a registry entry by name
58    pub fn get(&self, name: &str) -> Option<&RegistryEntry> {
59        self.entries.get(name)
60    }
61
62    /// List all available dataset names
63    pub fn list_datasets(&self) -> Vec<String> {
64        self.entries.keys().cloned().collect()
65    }
66
67    /// Check if a dataset is registered
68    pub fn contains(&self, name: &str) -> bool {
69        self.entries.contains_key(name)
70    }
71
72    /// Get metadata for a dataset
73    pub fn get_metadata(&self, name: &str) -> Result<DatasetMetadata> {
74        match name {
75            "iris" => Ok(DatasetMetadata {
76                name: "Iris".to_string(),
77                description: "Classic iris flower dataset for classification".to_string(),
78                n_samples: 150,
79                n_features: 4,
80                task_type: "classification".to_string(),
81                targetnames: Some(vec![
82                    "setosa".to_string(),
83                    "versicolor".to_string(),
84                    "virginica".to_string(),
85                ]),
86                featurenames: Some(vec![
87                    "sepal_length".to_string(),
88                    "sepal_width".to_string(),
89                    "petal_length".to_string(),
90                    "petal_width".to_string(),
91                ]),
92                url: None,
93                checksum: None,
94            }),
95            "boston" => Ok(DatasetMetadata {
96                name: "Boston Housing".to_string(),
97                description: "Boston housing prices dataset for regression".to_string(),
98                n_samples: 506,
99                n_features: 13,
100                task_type: "regression".to_string(),
101                targetnames: None,
102                featurenames: None,
103                url: None,
104                checksum: None,
105            }),
106            "digits" => Ok(DatasetMetadata {
107                name: "Digits".to_string(),
108                description: "Hand-written digits dataset for image classification".to_string(),
109                n_samples: 1797,
110                n_features: 64,
111                task_type: "classification".to_string(),
112                targetnames: Some(vec![
113                    "0".to_string(),
114                    "1".to_string(),
115                    "2".to_string(),
116                    "3".to_string(),
117                    "4".to_string(),
118                    "5".to_string(),
119                    "6".to_string(),
120                    "7".to_string(),
121                    "8".to_string(),
122                    "9".to_string(),
123                ]),
124                featurenames: None,
125                url: None,
126                checksum: None,
127            }),
128            "wine" => Ok(DatasetMetadata {
129                name: "Wine".to_string(),
130                description: "Wine recognition dataset for classification".to_string(),
131                n_samples: 178,
132                n_features: 13,
133                task_type: "classification".to_string(),
134                targetnames: Some(vec![
135                    "class_0".to_string(),
136                    "class_1".to_string(),
137                    "class_2".to_string(),
138                ]),
139                featurenames: None,
140                url: None,
141                checksum: None,
142            }),
143            "breast_cancer" => Ok(DatasetMetadata {
144                name: "Breast Cancer".to_string(),
145                description: "Breast cancer wisconsin dataset for classification".to_string(),
146                n_samples: 569,
147                n_features: 30,
148                task_type: "classification".to_string(),
149                targetnames: Some(vec!["malignant".to_string(), "benign".to_string()]),
150                featurenames: None,
151                url: None,
152                checksum: None,
153            }),
154            "diabetes" => Ok(DatasetMetadata {
155                name: "Diabetes".to_string(),
156                description: "Diabetes dataset for regression".to_string(),
157                n_samples: 442,
158                n_features: 10,
159                task_type: "regression".to_string(),
160                targetnames: None,
161                featurenames: None,
162                url: None,
163                checksum: None,
164            }),
165            _ => Err(DatasetsError::Other(format!("Unknown dataset: {name}"))),
166        }
167    }
168
169    /// Populate the registry with default datasets
170    ///
171    /// This includes both local sample datasets and references to potential remote datasets.
172    /// Local datasets use verified SHA256 hashes computed from actual files.
173    fn populate_default_datasets(&mut self) {
174        // Local sample datasets (with verified SHA256 hashes)
175        self.register(
176            "example".to_string(),
177            RegistryEntry {
178                url: "file://data/example.csv",
179                sha256: "c51c3ff2e8a5db28b1baed809a2ba29f29643e5a26ad476448eb3889996173d6",
180            },
181        );
182
183        self.register(
184            "sample_data".to_string(),
185            RegistryEntry {
186                url: "file://examples/sample_data.csv",
187                sha256: "59cceb2c80692ee2c1c3b607335d1feb983ceed24214d1ffc2eace9f3ce5ab47",
188            },
189        );
190
191        // Built-in toy datasets (no hash needed as they're embedded in code)
192        self.register_toy_dataset("iris", "Classic iris flower dataset for classification");
193        self.register_toy_dataset("boston", "Boston housing prices dataset for regression");
194        self.register_toy_dataset(
195            "digits",
196            "Hand-written digits dataset for image classification",
197        );
198        self.register_toy_dataset("wine", "Wine recognition dataset for classification");
199        self.register_toy_dataset(
200            "breast_cancer",
201            "Breast cancer wisconsin dataset for classification",
202        );
203        self.register_toy_dataset("diabetes", "Diabetes dataset for regression");
204
205        // Future remote datasets (commented out until available)
206        /*
207        self.register(
208            "california_housing".to_string(),
209            RegistryEntry {
210                url: "https://raw.githubusercontent.com/cool-japan/scirs-datasets/main/california_housing.csv",
211                sha256: "a1b2c3d4e5f6789012345678901234567890abcdef1234567890abcdef123456",
212            },
213        );
214
215        self.register(
216            "electrocardiogram".to_string(),
217            RegistryEntry {
218                url: "https://raw.githubusercontent.com/cool-japan/scirs-datasets/main/electrocardiogram.json",
219                sha256: "def0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd",
220            },
221        );
222
223        self.register(
224            "stock_market".to_string(),
225            RegistryEntry {
226                url: "https://raw.githubusercontent.com/cool-japan/scirs-datasets/main/stock_market.json",
227                sha256: "456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef012345",
228            },
229        );
230
231        self.register(
232            "weather".to_string(),
233            RegistryEntry {
234                url: "https://raw.githubusercontent.com/cool-japan/scirs-datasets/main/weather.json",
235                sha256: "789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789",
236            },
237        );
238        */
239    }
240
241    /// Register a toy dataset (built-in datasets don't need URLs or hashes)
242    fn register_toy_dataset(&mut self, name: &str, _description: &str) {
243        let url = match name {
244            "iris" => "builtin://iris",
245            "boston" => "builtin://boston",
246            "digits" => "builtin://digits",
247            "wine" => "builtin://wine",
248            "breast_cancer" => "builtin://breast_cancer",
249            "diabetes" => "builtin://diabetes",
250            _ => "builtin://unknown",
251        };
252
253        self.register(
254            name.to_string(),
255            RegistryEntry {
256                url,
257                sha256: "builtin", // Special marker for built-in datasets
258            },
259        );
260    }
261}
262
263/// Get the global dataset registry
264#[allow(dead_code)]
265pub fn get_registry() -> DatasetRegistry {
266    DatasetRegistry::default()
267}
268
/// Load a dataset by name from the registry
///
/// The dataset's registry URL decides how it is loaded:
/// - `builtin://…` dispatches on `name` to the matching built-in loader,
/// - `file://…` loads the local file through `load_local_dataset`,
/// - URLs starting with `http` dispatch on `name` to the remote loaders.
///
/// `forcedownload` is only forwarded to the `california_housing` loader; the
/// other loaders are called with fixed arguments.
///
/// # Errors
///
/// Returns `DatasetsError::Other` when `name` is not registered (the message
/// lists the registered names), when no loader is wired up for the name, or
/// when the URL scheme is not one of the above. Errors from the individual
/// loaders are returned unchanged.
#[cfg(feature = "download")]
#[allow(dead_code)]
pub fn load_dataset_byname(name: &str, forcedownload: bool) -> Result<crate::utils::Dataset> {
    let registry = get_registry();

    let entry = match registry.get(name) {
        Some(entry) => entry,
        None => {
            return Err(DatasetsError::Other(format!(
                "Unknown dataset: '{}'. Available datasets: {:?}",
                name,
                registry.list_datasets()
            )))
        }
    };

    // Handle different URL schemes
    if entry.url.starts_with("builtin://") {
        // Built-in toy datasets
        match name {
            "iris" => crate::toy::load_iris(),
            "boston" => crate::toy::load_boston(),
            "digits" => crate::toy::load_digits(),
            // NOTE(review): `forcedownload` is not forwarded here — confirm
            // that the fixed `false` is intentional.
            "wine" => crate::sample::load_wine(false),
            "breast_cancer" => crate::toy::load_breast_cancer(),
            "diabetes" => crate::toy::load_diabetes(),
            _ => Err(DatasetsError::Other(format!(
                "Built-in dataset '{}' not implemented",
                name
            ))),
        }
    } else if let Some(relativepath) = entry.url.strip_prefix("file://") {
        // Local file datasets: whatever follows the scheme is the relative path
        load_local_dataset(name, relativepath, entry.sha256)
    } else if entry.url.starts_with("http") {
        // Remote datasets (when available)
        match name {
            "california_housing" => crate::sample::load_california_housing(forcedownload),
            "electrocardiogram" => crate::time_series::electrocardiogram(),
            "stock_market" => crate::time_series::stock_market(false),
            "weather" => crate::time_series::weather(None),
            _ => Err(DatasetsError::Other(format!(
                "Remote dataset '{}' not yet implemented for loading",
                name
            ))),
        }
    } else {
        Err(DatasetsError::Other(format!(
            "Unsupported URL scheme for dataset '{}': {}",
            name, entry.url
        )))
    }
}
320
/// Load a local dataset file
///
/// `relativepath` is resolved against `CARGO_MANIFEST_DIR`, i.e. the
/// directory of this crate's `Cargo.toml` as it was when the crate was
/// compiled (the value is baked in by `env!`). The file is read as CSV with a
/// header row, and the returned dataset's description is set to
/// `"Local dataset: <name>"`.
///
/// Unless `expected_sha256` is the literal `"builtin"`, the file's SHA-256 is
/// computed and must equal `expected_sha256`.
///
/// # Errors
///
/// Returns `DatasetsError::Other` when the file does not exist, when its hash
/// cannot be computed, or when the hash does not match. Errors from the CSV
/// loader are propagated unchanged.
#[cfg(feature = "download")]
#[allow(dead_code)]
fn load_local_dataset(
    name: &str,
    relativepath: &str,
    expected_sha256: &str,
) -> Result<crate::utils::Dataset> {
    use crate::loaders::{load_csv, CsvConfig};
    use std::path::Path;

    // Build the absolute path from the crate's manifest directory
    let filepath = Path::new(env!("CARGO_MANIFEST_DIR")).join(relativepath);

    if !filepath.exists() {
        return Err(DatasetsError::Other(format!(
            "Local dataset file not found: {}",
            filepath.display()
        )));
    }

    // Verify SHA256 hash
    if expected_sha256 != "builtin" {
        // A hash that cannot be computed is treated as a verification
        // failure; skipping the check here would let unverified data through.
        let actual_hash = match crate::cache::sha256_hash_file(&filepath) {
            Ok(hash) => hash,
            Err(_) => {
                return Err(DatasetsError::Other(format!(
                    "Failed to compute SHA256 hash for dataset '{}' ({})",
                    name,
                    filepath.display()
                )))
            }
        };

        if actual_hash != expected_sha256 {
            return Err(DatasetsError::Other(format!(
                "Hash verification failed for dataset '{}'. Expected: {}, Got: {}",
                name, expected_sha256, actual_hash
            )));
        }
    }

    // Load the CSV file and attach a description naming the dataset
    let config = CsvConfig::default().with_header(true);
    let dataset = load_csv(&filepath, config)?;

    Ok(dataset.with_description(format!("Local dataset: {}", name)))
}
364
365#[cfg(not(feature = "download"))]
366/// Load a dataset by name from the registry (stub for when download feature is disabled)
367#[allow(dead_code)]
368pub fn load_dataset_byname(_name: &str, _forcedownload: bool) -> Result<crate::utils::Dataset> {
369    Err(DatasetsError::Other(
370        "Download feature is not enabled. Recompile with --features _download".to_string(),
371    ))
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// A registry made with `new` starts out empty.
    #[test]
    fn test_registry_creation() {
        let registry = DatasetRegistry::new();
        assert!(registry.entries.is_empty());
    }

    /// The default registry contains every local and built-in dataset.
    #[test]
    fn test_registry_default() {
        let registry = DatasetRegistry::default();
        assert!(!registry.entries.is_empty());

        // Test local datasets
        assert!(registry.contains("example"));
        assert!(registry.contains("sample_data"));

        // Test built-in toy datasets
        assert!(registry.contains("iris"));
        assert!(registry.contains("boston"));
        assert!(registry.contains("wine"));
        assert!(registry.contains("digits"));
        assert!(registry.contains("breast_cancer"));
        assert!(registry.contains("diabetes"));
    }

    /// Register, look up and list a single custom entry.
    #[test]
    fn test_registry_operations() {
        let mut registry = DatasetRegistry::new();

        let entry = RegistryEntry {
            url: "https://example.com/test.csv",
            sha256: "abcd1234",
        };

        registry.register("test_dataset".to_string(), entry);

        assert!(registry.contains("test_dataset"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry
            .get("test_dataset")
            .expect("entry registered above must be retrievable");
        assert_eq!(retrieved.url, "https://example.com/test.csv");
        assert_eq!(retrieved.sha256, "abcd1234");

        let datasets = registry.list_datasets();
        assert_eq!(datasets.len(), 1);
        assert!(datasets.contains(&"test_dataset".to_string()));
    }

    /// `get_registry` hands back a populated registry.
    #[test]
    fn test_get_registry() {
        let registry = get_registry();
        assert!(!registry.list_datasets().is_empty());
    }

    /// Built-in and local entries carry the expected URL scheme and hash.
    #[test]
    fn test_registry_url_schemes() {
        let registry = DatasetRegistry::default();

        // `expect` (not `if let`) so a missing entry fails the test instead
        // of silently skipping the assertions.

        // Test built-in datasets have builtin:// URLs
        let iris_entry = registry
            .get("iris")
            .expect("'iris' must be in the default registry");
        assert_eq!(iris_entry.url, "builtin://iris");
        assert_eq!(iris_entry.sha256, "builtin");

        // Test local datasets have file:// URLs
        let example_entry = registry
            .get("example")
            .expect("'example' must be in the default registry");
        assert_eq!(example_entry.url, "file://data/example.csv");
        assert_eq!(
            example_entry.sha256,
            "c51c3ff2e8a5db28b1baed809a2ba29f29643e5a26ad476448eb3889996173d6"
        );
    }

    /// The default registry holds exactly the expected set of names.
    #[test]
    fn test_dataset_count() {
        let registry = DatasetRegistry::default();
        let datasets = registry.list_datasets();

        // Should have 2 local datasets + 6 built-in toy datasets = 8 total
        assert_eq!(datasets.len(), 8);

        // Verify all expected datasets are present
        let expected_datasets = [
            "example",
            "sample_data", // local
            "iris",
            "boston",
            "digits",
            "wine",
            "breast_cancer",
            "diabetes", // built-in
        ];

        for expected in expected_datasets {
            assert!(
                datasets.contains(&expected.to_string()),
                "Dataset '{expected}' not found in registry"
            );
        }
    }
}