1use crate::cache::{fetch_data, RegistryEntry};
12use crate::error::{DatasetsError, Result};
13use crate::utils::Dataset;
14use scirs2_core::ndarray::{Array1, Array2};
15use serde::Deserialize;
16use std::collections::HashMap;
17use std::fs;
18
19lazy_static::lazy_static! {
21 static ref REGISTRY: HashMap<&'static str, RegistryEntry> = {
22 let mut registry = HashMap::new();
23
24 registry.insert("ecg.dat", RegistryEntry {
26 sha256: "f20ad3365fb9b7f845d0e5c48b6fe67081377ee466c3a220b7f69f35c8958baf",
27 url: "https://raw.githubusercontent.com/scipy/dataset-ecg/main/ecg.dat",
28 });
29
30 registry.insert("stock_market.csv", RegistryEntry {
32 sha256: "e6d5392bd79e82e3f6d7fe171d8c2fafae84b1a4e9e95a532ec252caa3053dc9",
33 url: "https://raw.githubusercontent.com/scirs/datasets/main/stock_market.csv",
34 });
35
36 registry.insert("weather.csv", RegistryEntry {
38 sha256: "f8bdaef6d968c1eddb0c0c7cf9c245b07d60ffe3a7d8e5ed8953f5750ee0f610",
39 url: "https://raw.githubusercontent.com/scirs/datasets/main/weather.csv",
40 });
41
42 registry
43 };
44}
45
46#[allow(dead_code)]
67pub fn electrocardiogram() -> Result<Dataset> {
68 let ecg_file = match fetch_data("ecg.dat", REGISTRY.get("ecg.dat")) {
70 Ok(path) => path,
71 Err(e) => {
72 return Err(DatasetsError::LoadingError(format!(
73 "Failed to fetch ECG data: {e}"
74 )))
75 }
76 };
77
78 let ecg_data = match fs::read(ecg_file) {
80 Ok(data) => data,
81 Err(e) => {
82 return Err(DatasetsError::LoadingError(format!(
83 "Failed to read ECG data: {e}"
84 )))
85 }
86 };
87
88 let mut ecg_values = Vec::with_capacity(ecg_data.len() / 2);
90 let mut i = 0;
91 while i < ecg_data.len() {
92 if i + 1 < ecg_data.len() {
93 let value = (ecg_data[i] as u16) | ((ecg_data[i + 1] as u16) << 8);
94 ecg_values.push(value);
95 }
96 i += 2;
97 }
98
99 let ecg_values = ecg_values
102 .into_iter()
103 .map(|x| (x as f64 - 1024.0) / 200.0)
104 .collect::<Vec<f64>>();
105
106 let ecg_array = Array1::from_vec(ecg_values);
107
108 let len = ecg_array.len();
110
111 let data = ecg_array.into_shape_with_order((len, 1)).unwrap();
113
114 let mut dataset = Dataset::new(data, None);
116 dataset = dataset
117 .with_featurenames(vec!["ecg".to_string()])
118 .with_description("Electrocardiogram (ECG) data, 5 minutes sampled at 360 Hz".to_string())
119 .with_metadata("sampling_rate", "360")
120 .with_metadata("units", "mV")
121 .with_metadata("duration", "5 minutes");
122
123 Ok(dataset)
124}
125
/// One row of `stock_market.csv`, deserialized through `serde` by the `csv`
/// reader in [`stock_market`].
///
/// NOTE(review): with the `csv` reader's default header handling, field names are
/// presumably matched against the CSV header row. Check the file before renaming.
#[derive(Debug, Deserialize)]
struct StockPrice {
    /// Date string; `stock_market` sorts these as plain strings to order rows.
    date: String,
    /// `open` column; subtracted from `close` when returns are requested.
    open: f64,
    // Parsed but not read by `stock_market`.
    #[allow(dead_code)]
    high: f64,
    // Parsed but not read by `stock_market`.
    #[allow(dead_code)]
    low: f64,
    /// `close` column; the cell value when prices are requested.
    close: f64,
    // Parsed but not read by `stock_market`.
    #[allow(dead_code)]
    volume: f64,
    /// Ticker symbol; each distinct value becomes one column.
    symbol: String,
}
140
141#[allow(dead_code)]
167pub fn stock_market(returns: bool) -> Result<Dataset> {
168 let stock_file = match fetch_data("stock_market.csv", REGISTRY.get("stock_market.csv")) {
170 Ok(path) => path,
171 Err(e) => {
172 return Err(DatasetsError::LoadingError(format!(
173 "Failed to fetch stock market data: {e}"
174 )))
175 }
176 };
177
178 let file_content = match fs::read_to_string(&stock_file) {
180 Ok(content) => content,
181 Err(e) => {
182 return Err(DatasetsError::LoadingError(format!(
183 "Failed to read stock market data: {e}"
184 )))
185 }
186 };
187
188 let mut reader = csv::Reader::from_reader(file_content.as_bytes());
189 let records: Result<Vec<StockPrice>> = reader
190 .deserialize()
191 .map(|result| {
192 result.map_err(|e| DatasetsError::LoadingError(format!("CSV parsing error: {e}")))
193 })
194 .collect();
195
196 let records = records?;
197 if records.is_empty() {
198 return Err(DatasetsError::LoadingError(
199 "Stock market data is empty".to_string(),
200 ));
201 }
202
203 let mut symbols = Vec::new();
205 let mut dates = Vec::new();
206 for record in &records {
207 if !symbols.contains(&record.symbol) {
208 symbols.push(record.symbol.clone());
209 }
210 if !dates.contains(&record.date) {
211 dates.push(record.date.clone());
212 }
213 }
214
215 symbols.sort();
216 dates.sort();
217
218 let mut date_symbol_map = HashMap::new();
220 for record in &records {
221 date_symbol_map.insert((record.date.clone(), record.symbol.clone()), record);
222 }
223
224 let mut data = Array2::zeros((dates.len(), symbols.len()));
226
227 for (i, date) in dates.iter().enumerate() {
228 for (j, symbol) in symbols.iter().enumerate() {
229 if let Some(record) = date_symbol_map.get(&(date.clone(), symbol.clone())) {
230 data[[i, j]] = if returns {
231 record.close - record.open
232 } else {
233 record.close
234 };
235 }
236 }
237 }
238
239 let mut dataset = Dataset::new(data, None);
241 dataset = dataset
242 .with_featurenames(symbols.clone())
243 .with_description(format!(
244 "Stock market data for {} companies from {} to {}",
245 symbols.len(),
246 dates.first().unwrap_or(&"unknown".to_string()),
247 dates.last().unwrap_or(&"unknown".to_string())
248 ))
249 .with_metadata("n_symbols", &symbols.len().to_string())
250 .with_metadata(
251 "start_date",
252 dates.first().unwrap_or(&"unknown".to_string()),
253 )
254 .with_metadata("end_date", dates.last().unwrap_or(&"unknown".to_string()))
255 .with_metadata("data_type", if returns { "_returns" } else { "prices" });
256
257 Ok(dataset)
258}
259
/// One row of `weather.csv`, deserialized through `serde` by the `csv` reader in
/// [`weather`].
///
/// NOTE(review): with the `csv` reader's default header handling, field names are
/// presumably matched against the CSV header row. Check the file before renaming.
#[derive(Debug, Deserialize)]
struct WeatherObservation {
    /// Date string; `weather` sorts these as plain strings to order rows.
    date: String,
    /// Value of the `temperature` column (units not stated in this module).
    temperature: f64,
    /// Value of the `humidity` column.
    humidity: f64,
    /// Value of the `pressure` column.
    pressure: f64,
    /// Value of the `wind_speed` column.
    wind_speed: f64,
    /// Value of the `precipitation` column.
    precipitation: f64,
    /// Location label; each distinct value becomes one column (or group of columns).
    location: String,
}
271
272#[allow(dead_code)]
305pub fn weather(feature: Option<&str>) -> Result<Dataset> {
306 let valid_features = vec![
308 "temperature",
309 "humidity",
310 "pressure",
311 "wind_speed",
312 "precipitation",
313 ];
314
315 if let Some(f) = feature {
316 if !valid_features.contains(&f) {
317 return Err(DatasetsError::InvalidFormat(format!(
318 "Invalid _feature: {f}. Valid features are: {valid_features:?}"
319 )));
320 }
321 }
322
323 let weather_file = match fetch_data("weather.csv", REGISTRY.get("weather.csv")) {
325 Ok(path) => path,
326 Err(e) => {
327 return Err(DatasetsError::LoadingError(format!(
328 "Failed to fetch weather data: {e}"
329 )))
330 }
331 };
332
333 let file_content = match fs::read_to_string(&weather_file) {
335 Ok(content) => content,
336 Err(e) => {
337 return Err(DatasetsError::LoadingError(format!(
338 "Failed to read weather data: {e}"
339 )))
340 }
341 };
342
343 let mut reader = csv::Reader::from_reader(file_content.as_bytes());
344 let records: Result<Vec<WeatherObservation>> = reader
345 .deserialize()
346 .map(|result| {
347 result.map_err(|e| DatasetsError::LoadingError(format!("CSV parsing error: {e}")))
348 })
349 .collect();
350
351 let records = records?;
352 if records.is_empty() {
353 return Err(DatasetsError::LoadingError(
354 "Weather data is empty".to_string(),
355 ));
356 }
357
358 let mut locations = Vec::new();
360 let mut dates = Vec::new();
361 for record in &records {
362 if !locations.contains(&record.location) {
363 locations.push(record.location.clone());
364 }
365 if !dates.contains(&record.date) {
366 dates.push(record.date.clone());
367 }
368 }
369
370 locations.sort();
371 dates.sort();
372
373 let mut date_location_map = HashMap::new();
375 for record in &records {
376 date_location_map.insert((record.date.clone(), record.location.clone()), record);
377 }
378
379 let mut dataset = match feature {
380 Some(feat) => {
381 let mut data = Array2::zeros((dates.len(), locations.len()));
383
384 for (i, date) in dates.iter().enumerate() {
385 for (j, location) in locations.iter().enumerate() {
386 if let Some(record) = date_location_map.get(&(date.clone(), location.clone())) {
387 data[[i, j]] = match feat {
388 "temperature" => record.temperature,
389 "humidity" => record.humidity,
390 "pressure" => record.pressure,
391 "wind_speed" => record.wind_speed,
392 "precipitation" => record.precipitation,
393 _ => 0.0, };
395 }
396 }
397 }
398
399 let mut ds = Dataset::new(data, None);
401
402 ds = ds
404 .with_featurenames(locations.clone())
405 .with_description(format!(
406 "Weather {} data for {} locations from {} to {}",
407 feat,
408 locations.len(),
409 dates.first().unwrap_or(&"unknown".to_string()),
410 dates.last().unwrap_or(&"unknown".to_string())
411 ))
412 .with_metadata("_feature", feat)
413 .with_metadata("n_locations", &locations.len().to_string())
414 .with_metadata(
415 "start_date",
416 dates.first().unwrap_or(&"unknown".to_string()),
417 )
418 .with_metadata("end_date", dates.last().unwrap_or(&"unknown".to_string()));
419
420 ds
421 }
422 None => {
423 let n_features = valid_features.len();
426 let mut data = Array2::zeros((dates.len(), n_features * locations.len()));
427
428 for (i, date) in dates.iter().enumerate() {
429 for (j, location) in locations.iter().enumerate() {
430 if let Some(record) = date_location_map.get(&(date.clone(), location.clone())) {
431 let base_col = j * n_features;
433
434 data[[i, base_col]] = record.temperature;
436 data[[i, base_col + 1]] = record.humidity;
437 data[[i, base_col + 2]] = record.pressure;
438 data[[i, base_col + 3]] = record.wind_speed;
439 data[[i, base_col + 4]] = record.precipitation;
440 }
441 }
442 }
443
444 let mut featurenames = Vec::with_capacity(n_features * locations.len());
446 for location in &locations {
447 for feat in &valid_features {
448 featurenames.push(format!("{location}_{feat}"));
449 }
450 }
451
452 let mut ds = Dataset::new(data, None);
454 ds = ds
455 .with_featurenames(featurenames)
456 .with_description(format!(
457 "Weather data (all features) for {} locations from {} to {}",
458 locations.len(),
459 dates.first().unwrap_or(&"unknown".to_string()),
460 dates.last().unwrap_or(&"unknown".to_string())
461 ))
462 .with_metadata("features", &valid_features.join(","))
463 .with_metadata("n_locations", &locations.len().to_string())
464 .with_metadata(
465 "start_date",
466 dates.first().unwrap_or(&"unknown".to_string()),
467 )
468 .with_metadata("end_date", dates.last().unwrap_or(&"unknown".to_string()));
469
470 ds
471 }
472 };
473
474 dataset = dataset.with_metadata("locations", &locations.join(","));
476
477 Ok(dataset)
478}
479
480