1use crate::cache::{fetch_data, RegistryEntry};
12use crate::error::{DatasetsError, Result};
13use crate::utils::Dataset;
14use scirs2_core::ndarray::{Array1, Array2};
15use serde::Deserialize;
16use std::collections::HashMap;
17use std::fs;
18
19lazy_static::lazy_static! {
21    static ref REGISTRY: HashMap<&'static str, RegistryEntry> = {
22        let mut registry = HashMap::new();
23
24        registry.insert("ecg.dat", RegistryEntry {
26            sha256: "f20ad3365fb9b7f845d0e5c48b6fe67081377ee466c3a220b7f69f35c8958baf",
27            url: "https://raw.githubusercontent.com/scipy/dataset-ecg/main/ecg.dat",
28        });
29
30        registry.insert("stock_market.csv", RegistryEntry {
32            sha256: "e6d5392bd79e82e3f6d7fe171d8c2fafae84b1a4e9e95a532ec252caa3053dc9",
33            url: "https://raw.githubusercontent.com/scirs/datasets/main/stock_market.csv",
34        });
35
36        registry.insert("weather.csv", RegistryEntry {
38            sha256: "f8bdaef6d968c1eddb0c0c7cf9c245b07d60ffe3a7d8e5ed8953f5750ee0f610",
39            url: "https://raw.githubusercontent.com/scirs/datasets/main/weather.csv",
40        });
41
42        registry
43    };
44}
45
46#[allow(dead_code)]
67pub fn electrocardiogram() -> Result<Dataset> {
68    let ecg_file = match fetch_data("ecg.dat", REGISTRY.get("ecg.dat")) {
70        Ok(path) => path,
71        Err(e) => {
72            return Err(DatasetsError::LoadingError(format!(
73                "Failed to fetch ECG data: {e}"
74            )))
75        }
76    };
77
78    let ecg_data = match fs::read(ecg_file) {
80        Ok(data) => data,
81        Err(e) => {
82            return Err(DatasetsError::LoadingError(format!(
83                "Failed to read ECG data: {e}"
84            )))
85        }
86    };
87
88    let mut ecg_values = Vec::with_capacity(ecg_data.len() / 2);
90    let mut i = 0;
91    while i < ecg_data.len() {
92        if i + 1 < ecg_data.len() {
93            let value = (ecg_data[i] as u16) | ((ecg_data[i + 1] as u16) << 8);
94            ecg_values.push(value);
95        }
96        i += 2;
97    }
98
99    let ecg_values = ecg_values
102        .into_iter()
103        .map(|x| (x as f64 - 1024.0) / 200.0)
104        .collect::<Vec<f64>>();
105
106    let ecg_array = Array1::from_vec(ecg_values);
107
108    let len = ecg_array.len();
110
111    let data = ecg_array.into_shape_with_order((len, 1)).unwrap();
113
114    let mut dataset = Dataset::new(data, None);
116    dataset = dataset
117        .with_featurenames(vec!["ecg".to_string()])
118        .with_description("Electrocardiogram (ECG) data, 5 minutes sampled at 360 Hz".to_string())
119        .with_metadata("sampling_rate", "360")
120        .with_metadata("units", "mV")
121        .with_metadata("duration", "5 minutes");
122
123    Ok(dataset)
124}
125
/// One row of `stock_market.csv`: the prices of a single symbol on a single date.
///
/// Records are read with a header-aware `csv::Reader` and deserialized by
/// column name. The field names below must therefore match the CSV header.
#[derive(Debug, Deserialize)]
struct StockPrice {
    /// Date as written in the CSV. It is kept as text and ordered as a string by `stock_market`.
    date: String,
    /// Opening price.
    open: f64,
    // Parsed so that the row must contain this column, but not read by `stock_market`.
    #[allow(dead_code)]
    high: f64,
    // Parsed so that the row must contain this column, but not read by `stock_market`.
    #[allow(dead_code)]
    low: f64,
    /// Closing price.
    close: f64,
    // Parsed so that the row must contain this column, but not read by `stock_market`.
    #[allow(dead_code)]
    volume: f64,
    /// Ticker symbol; each distinct symbol becomes one feature column.
    symbol: String,
}
140
141#[allow(dead_code)]
167pub fn stock_market(returns: bool) -> Result<Dataset> {
168    let stock_file = match fetch_data("stock_market.csv", REGISTRY.get("stock_market.csv")) {
170        Ok(path) => path,
171        Err(e) => {
172            return Err(DatasetsError::LoadingError(format!(
173                "Failed to fetch stock market data: {e}"
174            )))
175        }
176    };
177
178    let file_content = match fs::read_to_string(&stock_file) {
180        Ok(content) => content,
181        Err(e) => {
182            return Err(DatasetsError::LoadingError(format!(
183                "Failed to read stock market data: {e}"
184            )))
185        }
186    };
187
188    let mut reader = csv::Reader::from_reader(file_content.as_bytes());
189    let records: Result<Vec<StockPrice>> = reader
190        .deserialize()
191        .map(|result| {
192            result.map_err(|e| DatasetsError::LoadingError(format!("CSV parsing error: {e}")))
193        })
194        .collect();
195
196    let records = records?;
197    if records.is_empty() {
198        return Err(DatasetsError::LoadingError(
199            "Stock market data is empty".to_string(),
200        ));
201    }
202
203    let mut symbols = Vec::new();
205    let mut dates = Vec::new();
206    for record in &records {
207        if !symbols.contains(&record.symbol) {
208            symbols.push(record.symbol.clone());
209        }
210        if !dates.contains(&record.date) {
211            dates.push(record.date.clone());
212        }
213    }
214
215    symbols.sort();
216    dates.sort();
217
218    let mut date_symbol_map = HashMap::new();
220    for record in &records {
221        date_symbol_map.insert((record.date.clone(), record.symbol.clone()), record);
222    }
223
224    let mut data = Array2::zeros((dates.len(), symbols.len()));
226
227    for (i, date) in dates.iter().enumerate() {
228        for (j, symbol) in symbols.iter().enumerate() {
229            if let Some(record) = date_symbol_map.get(&(date.clone(), symbol.clone())) {
230                data[[i, j]] = if returns {
231                    record.close - record.open
232                } else {
233                    record.close
234                };
235            }
236        }
237    }
238
239    let mut dataset = Dataset::new(data, None);
241    dataset = dataset
242        .with_featurenames(symbols.clone())
243        .with_description(format!(
244            "Stock market data for {} companies from {} to {}",
245            symbols.len(),
246            dates.first().unwrap_or(&"unknown".to_string()),
247            dates.last().unwrap_or(&"unknown".to_string())
248        ))
249        .with_metadata("n_symbols", &symbols.len().to_string())
250        .with_metadata(
251            "start_date",
252            dates.first().unwrap_or(&"unknown".to_string()),
253        )
254        .with_metadata("end_date", dates.last().unwrap_or(&"unknown".to_string()))
255        .with_metadata("data_type", if returns { "_returns" } else { "prices" });
256
257    Ok(dataset)
258}
259
/// One row of `weather.csv`: the observations for one location on one date.
///
/// Records are read with a header-aware `csv::Reader` and deserialized by
/// column name. The field names below must therefore match the CSV header.
#[derive(Debug, Deserialize)]
struct WeatherObservation {
    /// Date as written in the CSV. It is kept as text and ordered as a string by `weather`.
    date: String,
    /// Temperature reading (units not stated in this file — TODO confirm).
    temperature: f64,
    /// Humidity reading (units not stated in this file — TODO confirm).
    humidity: f64,
    /// Pressure reading (units not stated in this file — TODO confirm).
    pressure: f64,
    /// Wind speed reading (units not stated in this file — TODO confirm).
    wind_speed: f64,
    /// Precipitation reading (units not stated in this file — TODO confirm).
    precipitation: f64,
    /// Location name; each distinct location becomes one column group.
    location: String,
}
271
272#[allow(dead_code)]
305pub fn weather(feature: Option<&str>) -> Result<Dataset> {
306    let valid_features = vec![
308        "temperature",
309        "humidity",
310        "pressure",
311        "wind_speed",
312        "precipitation",
313    ];
314
315    if let Some(f) = feature {
316        if !valid_features.contains(&f) {
317            return Err(DatasetsError::InvalidFormat(format!(
318                "Invalid _feature: {f}. Valid features are: {valid_features:?}"
319            )));
320        }
321    }
322
323    let weather_file = match fetch_data("weather.csv", REGISTRY.get("weather.csv")) {
325        Ok(path) => path,
326        Err(e) => {
327            return Err(DatasetsError::LoadingError(format!(
328                "Failed to fetch weather data: {e}"
329            )))
330        }
331    };
332
333    let file_content = match fs::read_to_string(&weather_file) {
335        Ok(content) => content,
336        Err(e) => {
337            return Err(DatasetsError::LoadingError(format!(
338                "Failed to read weather data: {e}"
339            )))
340        }
341    };
342
343    let mut reader = csv::Reader::from_reader(file_content.as_bytes());
344    let records: Result<Vec<WeatherObservation>> = reader
345        .deserialize()
346        .map(|result| {
347            result.map_err(|e| DatasetsError::LoadingError(format!("CSV parsing error: {e}")))
348        })
349        .collect();
350
351    let records = records?;
352    if records.is_empty() {
353        return Err(DatasetsError::LoadingError(
354            "Weather data is empty".to_string(),
355        ));
356    }
357
358    let mut locations = Vec::new();
360    let mut dates = Vec::new();
361    for record in &records {
362        if !locations.contains(&record.location) {
363            locations.push(record.location.clone());
364        }
365        if !dates.contains(&record.date) {
366            dates.push(record.date.clone());
367        }
368    }
369
370    locations.sort();
371    dates.sort();
372
373    let mut date_location_map = HashMap::new();
375    for record in &records {
376        date_location_map.insert((record.date.clone(), record.location.clone()), record);
377    }
378
379    let mut dataset = match feature {
380        Some(feat) => {
381            let mut data = Array2::zeros((dates.len(), locations.len()));
383
384            for (i, date) in dates.iter().enumerate() {
385                for (j, location) in locations.iter().enumerate() {
386                    if let Some(record) = date_location_map.get(&(date.clone(), location.clone())) {
387                        data[[i, j]] = match feat {
388                            "temperature" => record.temperature,
389                            "humidity" => record.humidity,
390                            "pressure" => record.pressure,
391                            "wind_speed" => record.wind_speed,
392                            "precipitation" => record.precipitation,
393                            _ => 0.0, };
395                    }
396                }
397            }
398
399            let mut ds = Dataset::new(data, None);
401
402            ds = ds
404                .with_featurenames(locations.clone())
405                .with_description(format!(
406                    "Weather {} data for {} locations from {} to {}",
407                    feat,
408                    locations.len(),
409                    dates.first().unwrap_or(&"unknown".to_string()),
410                    dates.last().unwrap_or(&"unknown".to_string())
411                ))
412                .with_metadata("_feature", feat)
413                .with_metadata("n_locations", &locations.len().to_string())
414                .with_metadata(
415                    "start_date",
416                    dates.first().unwrap_or(&"unknown".to_string()),
417                )
418                .with_metadata("end_date", dates.last().unwrap_or(&"unknown".to_string()));
419
420            ds
421        }
422        None => {
423            let n_features = valid_features.len();
426            let mut data = Array2::zeros((dates.len(), n_features * locations.len()));
427
428            for (i, date) in dates.iter().enumerate() {
429                for (j, location) in locations.iter().enumerate() {
430                    if let Some(record) = date_location_map.get(&(date.clone(), location.clone())) {
431                        let base_col = j * n_features;
433
434                        data[[i, base_col]] = record.temperature;
436                        data[[i, base_col + 1]] = record.humidity;
437                        data[[i, base_col + 2]] = record.pressure;
438                        data[[i, base_col + 3]] = record.wind_speed;
439                        data[[i, base_col + 4]] = record.precipitation;
440                    }
441                }
442            }
443
444            let mut featurenames = Vec::with_capacity(n_features * locations.len());
446            for location in &locations {
447                for feat in &valid_features {
448                    featurenames.push(format!("{location}_{feat}"));
449                }
450            }
451
452            let mut ds = Dataset::new(data, None);
454            ds = ds
455                .with_featurenames(featurenames)
456                .with_description(format!(
457                    "Weather data (all features) for {} locations from {} to {}",
458                    locations.len(),
459                    dates.first().unwrap_or(&"unknown".to_string()),
460                    dates.last().unwrap_or(&"unknown".to_string())
461                ))
462                .with_metadata("features", &valid_features.join(","))
463                .with_metadata("n_locations", &locations.len().to_string())
464                .with_metadata(
465                    "start_date",
466                    dates.first().unwrap_or(&"unknown".to_string()),
467                )
468                .with_metadata("end_date", dates.last().unwrap_or(&"unknown".to_string()));
469
470            ds
471        }
472    };
473
474    dataset = dataset.with_metadata("locations", &locations.join(","));
476
477    Ok(dataset)
478}
479
480