use std::collections::HashMap;
use std::io::Read;
use std::path::Path;
use std::time::Duration;

use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};

use crate::cache::DatasetCache;
use crate::error::{DatasetsError, Result};
use crate::loaders::{load_csv, CsvConfig};
use crate::utils::Dataset;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalConfig {
    pub timeout_seconds: u64,
    pub max_retries: u32,
    pub user_agent: String,
    pub headers: HashMap<String, String>,
    pub verify_ssl: bool,
    pub use_cache: bool,
}

impl Default for ExternalConfig {
    fn default() -> Self {
        Self {
            timeout_seconds: 300,
            max_retries: 3,
            user_agent: "scirs2-datasets/0.1.0".to_string(),
            headers: HashMap::new(),
            verify_ssl: true,
            use_cache: true,
        }
    }
}

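/// Callback for download progress, invoked as `(bytes_downloaded, total_bytes)`;
/// `total_bytes` is 0 when the server does not report a content length.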
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;

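/// Client for downloading and parsing external datasets, with local caching
/// and retry support. With the `download` feature it uses an async `reqwest`
/// client; without it, a blocking `ureq` fallback is used.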
pub struct ExternalClient {
    config: ExternalConfig,
    cache: DatasetCache,
    #[cfg(feature = "download")]
    client: reqwest::Client,
}

impl ExternalClient {
    pub fn new() -> Result<Self> {
        Self::with_config(ExternalConfig::default())
    }

    pub fn with_config(config: ExternalConfig) -> Result<Self> {
        let cache = DatasetCache::new(crate::cache::get_cachedir()?);

        #[cfg(feature = "download")]
        let client = {
            let mut builder = reqwest::Client::builder()
                .timeout(Duration::from_secs(config.timeout_seconds))
                .user_agent(&config.user_agent);

            if !config.verify_ssl {
                builder = builder.danger_accept_invalid_certs(true);
            }

            builder
                .build()
                .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?
        };

        Ok(Self {
            config,
            cache,
            #[cfg(feature = "download")]
            client,
        })
    }

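    /// Downloads a dataset from `url`, reporting progress through the optional
    /// callback and dispatching the parser on the URL's file extension. When
    /// caching is enabled, results are keyed by a BLAKE3 hash of the URL.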
    #[cfg(feature = "download")]
    pub async fn download_dataset(
        &self,
        url: &str,
        progress: Option<ProgressCallback>,
    ) -> Result<Dataset> {
        if self.config.use_cache {
            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
            if let Ok(cached_data) = self.cache.read_cached(&cache_key) {
                return self.parse_cached_data(&cached_data);
            }
        }

        let response = self.make_request(url).await?;
        let total_size = response.content_length().unwrap_or(0);

        let mut downloaded = 0u64;
        let mut buffer = Vec::new();
        let mut stream = response.bytes_stream();

        use futures_util::StreamExt;
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
            downloaded += chunk.len() as u64;
            buffer.extend_from_slice(&chunk);

            if let Some(ref callback) = progress {
                callback(downloaded, total_size);
            }
        }

        if self.config.use_cache {
            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
            let _ = self.cache.put(&cache_key, &buffer);
        }

        self.parse_downloaded_data(url, &buffer)
    }

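    /// Blocking wrapper around [`Self::download_dataset`]; creates a fresh
    /// Tokio runtime for the duration of the call.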
    #[cfg(feature = "download")]
    pub fn download_dataset_sync(
        &self,
        url: &str,
        progress: Option<ProgressCallback>,
    ) -> Result<Dataset> {
        let rt = tokio::runtime::Runtime::new()
            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
        rt.block_on(self.download_dataset(url, progress))
    }

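    /// Blocking download used when the `download` feature is disabled.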
    #[cfg(not(feature = "download"))]
    pub fn download_dataset_sync(
        &self,
        url: &str,
        progress: Option<ProgressCallback>,
    ) -> Result<Dataset> {
        self.download_with_ureq(url, progress)
    }

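    /// Blocking `ureq` path mirroring the async flow: cache lookup, request
    /// with configured headers, one progress report once the body is read,
    /// then cache write and parsing.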
    #[allow(dead_code)]
    fn download_with_ureq(&self, url: &str, progress: Option<ProgressCallback>) -> Result<Dataset> {
        if self.config.use_cache {
            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
            if let Ok(cached_data) = self.cache.read_cached(&cache_key) {
                return self.parse_cached_data(&cached_data);
            }
        }

        let mut request = ureq::get(url).header("User-Agent", &self.config.user_agent);

        for (key, value) in &self.config.headers {
            request = request.header(key, value);
        }

        let response = request
            .call()
            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;

        let headers = response.headers();
        let total_size = headers
            .get("Content-Length")
            .and_then(|hv| hv.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(0);

        let mut body = response.into_body();
        let buffer = body
            .read_to_vec()
            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
        let downloaded = buffer.len() as u64;
        if let Some(ref callback) = progress {
            callback(downloaded, total_size);
        }

        if self.config.use_cache {
            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
            let _ = self.cache.put(&cache_key, &buffer);
        }

        self.parse_downloaded_data(url, &buffer)
    }

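    /// Sends a GET request with the configured headers, retrying up to
    /// `max_retries` times with a linearly increasing delay (1s, 2s, ...).
    /// Non-success statuses are treated as retryable failures.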
    #[cfg(feature = "download")]
    async fn make_request(&self, url: &str) -> Result<reqwest::Response> {
        let mut request = self.client.get(url);

        for (key, value) in &self.config.headers {
            request = request.header(key, value);
        }

        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            match request
                .try_clone()
                .ok_or_else(|| {
                    DatasetsError::IoError(std::io::Error::other("Failed to clone request"))
                })?
                .send()
                .await
            {
                Ok(response) => {
                    if response.status().is_success() {
                        return Ok(response);
                    } else {
                        last_error = Some(DatasetsError::IoError(std::io::Error::other(format!(
                            "HTTP {}: {}",
                            response.status(),
                            response.status().canonical_reason().unwrap_or("Unknown")
                        ))));
                    }
                }
                Err(e) => {
                    last_error = Some(DatasetsError::IoError(std::io::Error::other(e)));
                }
            }

            if attempt < self.config.max_retries {
                tokio::time::sleep(Duration::from_millis(1000 * (attempt + 1) as u64)).await;
            }
        }

        Err(last_error.unwrap())
    }

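    /// Interprets cached bytes, first as a serialized `Dataset`, falling back
    /// to raw-data parsing.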
    fn parse_cached_data(&self, data: &[u8]) -> Result<Dataset> {
        if let Ok(dataset) = serde_json::from_slice::<Dataset>(data) {
            return Ok(dataset);
        }

        self.parse_raw_data(data, None)
    }

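    /// Parses downloaded bytes, using the URL's file extension as a format hint.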
    fn parse_downloaded_data(&self, url: &str, data: &[u8]) -> Result<Dataset> {
        let extension = Path::new(url)
            .extension()
            .and_then(|s| s.to_str())
            .unwrap_or("")
            .to_lowercase();

        self.parse_raw_data(data, Some(&extension))
    }

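    /// Parses raw bytes as CSV (the default when no hint is given), JSON, or ARFF.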
    fn parse_raw_data(&self, data: &[u8], extension: Option<&str>) -> Result<Dataset> {
        match extension {
            Some("csv") | None => {
                let csv_data = String::from_utf8(data.to_vec())
                    .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;

                let temp_file = tempfile::NamedTempFile::new().map_err(DatasetsError::IoError)?;

                std::fs::write(temp_file.path(), &csv_data).map_err(DatasetsError::IoError)?;

                load_csv(temp_file.path(), CsvConfig::default())
            }
            Some("json") => {
                let json_str = String::from_utf8(data.to_vec())
                    .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;

                serde_json::from_str(&json_str)
                    .map_err(|e| DatasetsError::FormatError(format!("Invalid JSON: {e}")))
            }
            Some("arff") => self.parse_arff_data(data),
            _ => self.auto_detect_and_parse(data),
        }
    }

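    /// Minimal ARFF parser: collects `@attribute` names, reads numeric rows
    /// after `@data` (rows with non-numeric values are skipped), and treats
    /// the last column as the target when more than one column is present.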
    fn parse_arff_data(&self, data: &[u8]) -> Result<Dataset> {
        let content = String::from_utf8(data.to_vec())
            .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;

        let lines = content.lines();
        let mut attributes = Vec::new();
        let mut data_section = false;
        let mut data_lines = Vec::new();

        for line in lines {
            let line = line.trim();

            if line.is_empty() || line.starts_with('%') {
                continue;
            }

            if line.to_lowercase().starts_with("@attribute") {
                let parts: Vec<&str> = line.split_whitespace().collect();
                if parts.len() >= 2 {
                    attributes.push(parts[1].to_string());
                }
            } else if line.to_lowercase().starts_with("@data") {
                data_section = true;
            } else if data_section {
                data_lines.push(line.to_string());
            }
        }

        let mut rows: Vec<Vec<f64>> = Vec::new();
        for line in data_lines {
            let values: Result<Vec<f64>> = line
                .split(',')
                .map(|s| {
                    s.trim()
                        .parse::<f64>()
                        .map_err(|_| DatasetsError::FormatError(format!("Invalid number: {s}")))
                })
                .collect();

            match values {
                Ok(row) => rows.push(row),
                Err(_) => continue,
            }
        }

        if rows.is_empty() {
            return Err(DatasetsError::FormatError(
                "No valid data rows found".to_string(),
            ));
        }

        let n_features = rows[0].len();
        let n_samples = rows.len();

        let (data_cols, target_col) = if n_features > 1 {
            (n_features - 1, Some(n_features - 1))
        } else {
            (n_features, None)
        };

        let mut data_vec = Vec::with_capacity(n_samples * data_cols);
        let mut target_vec = if target_col.is_some() {
            Some(Vec::with_capacity(n_samples))
        } else {
            None
        };

        for row in rows {
            for (i, &value) in row.iter().enumerate() {
                if i < data_cols {
                    data_vec.push(value);
                } else if let Some(ref mut targets) = target_vec {
                    targets.push(value);
                }
            }
        }

        let data = Array2::from_shape_vec((n_samples, data_cols), data_vec)
            .map_err(|e| DatasetsError::FormatError(e.to_string()))?;

        let target = target_vec.map(Array1::from_vec);

        Ok(Dataset {
            data,
            target,
            // Guard against headers that declare fewer @attribute entries
            // than there are data columns, which would otherwise panic.
            featurenames: if attributes.len() >= data_cols {
                Some(attributes[..data_cols].to_vec())
            } else {
                None
            },
            targetnames: None,
            feature_descriptions: None,
            description: Some("ARFF dataset loaded from external source".to_string()),
            metadata: std::collections::HashMap::new(),
        })
    }

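    /// Heuristic format detection for data without a usable extension:
    /// JSON content first, then ARFF (checked before the delimiter test,
    /// since ARFF data sections also contain commas), then delimited text
    /// as CSV.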
    fn auto_detect_and_parse(&self, data: &[u8]) -> Result<Dataset> {
        let content = String::from_utf8(data.to_vec())
            .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;

        if content.trim().starts_with('{') || content.trim().starts_with('[') {
            if let Ok(dataset) = serde_json::from_str::<Dataset>(&content) {
                return Ok(dataset);
            }
        }

        if content.to_lowercase().contains("@relation") {
            return self.parse_arff_data(data);
        }

        if content.contains(',') || content.contains('\t') {
            return self.parse_raw_data(data, Some("csv"));
        }

        Err(DatasetsError::FormatError(
            "Unable to auto-detect data format".to_string(),
        ))
    }
}

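/// Pre-configured clients for well-known dataset repositories.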
pub mod repositories {
    use super::*;

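    /// Client for the UCI Machine Learning Repository.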
    pub struct UCIRepository {
        client: ExternalClient,
        base_url: String,
    }

    impl UCIRepository {
        pub fn new() -> Result<Self> {
            Ok(Self {
                client: ExternalClient::new()?,
                base_url: "https://archive.ics.uci.edu/ml/machine-learning-databases".to_string(),
            })
        }

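        /// Loads a named UCI dataset; see [`Self::list_datasets`] for the
        /// supported names.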
        #[cfg(feature = "download")]
        pub async fn load_dataset(&self, name: &str) -> Result<Dataset> {
            let url = self.dataset_url(name)?;
            self.client.download_dataset(&url, None).await
        }

        #[cfg(not(feature = "download"))]
        pub fn load_dataset_sync(&self, name: &str) -> Result<Dataset> {
            let url = self.dataset_url(name)?;
            self.client.download_dataset_sync(&url, None)
        }

        fn dataset_url(&self, name: &str) -> Result<String> {
            let path = match name {
                "adult" => "adult/adult.data",
                "wine" => "wine/wine.data",
                "glass" => "glass/glass.data",
                "hepatitis" => "hepatitis/hepatitis.data",
                "heart-disease" => "heart-disease/processed.cleveland.data",
                _ => {
                    return Err(DatasetsError::NotFound(format!(
                        "UCI dataset '{name}' not found"
                    )))
                }
            };
            Ok(format!("{}/{}", self.base_url, path))
        }

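        /// Names accepted by the dataset loaders above.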
        pub fn list_datasets(&self) -> Vec<&'static str> {
            vec!["adult", "wine", "glass", "hepatitis", "heart-disease"]
        }
    }

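    /// Client for Kaggle competition data; downloads require an API key.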
    pub struct KaggleRepository {
        #[allow(dead_code)]
        client: ExternalClient,
        #[allow(dead_code)]
        api_key: Option<String>,
    }

    impl KaggleRepository {
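        /// Creates a client; when an API key is supplied it is sent as a
        /// `Bearer` token in the `Authorization` header.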
        pub fn new(api_key: Option<String>) -> Result<Self> {
            let mut config = ExternalConfig::default();

            if let Some(ref key) = api_key {
                config
                    .headers
                    .insert("Authorization".to_string(), format!("Bearer {key}"));
            }

            Ok(Self {
                client: ExternalClient::with_config(config)?,
                api_key,
            })
        }

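        /// Downloads a competition's data bundle, failing fast if no API key
        /// was configured.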
        #[cfg(feature = "download")]
        pub async fn load_competition_data(&self, competition: &str) -> Result<Dataset> {
            if self.api_key.is_none() {
                return Err(DatasetsError::AuthenticationError(
                    "Kaggle API key required".to_string(),
                ));
            }

            let url = format!(
                "https://www.kaggle.com/api/v1/competitions/{competition}/data/download"
            );
            self.client.download_dataset(&url, None).await
        }
    }

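    /// Client for loading raw files from GitHub repositories.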
    pub struct GitHubRepository {
        client: ExternalClient,
    }

    impl GitHubRepository {
        pub fn new() -> Result<Self> {
            Ok(Self {
                client: ExternalClient::new()?,
            })
        }

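        /// Loads `path` from the `main` branch of `user/repo` via
        /// raw.githubusercontent.com.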
        #[cfg(feature = "download")]
        pub async fn load_from_repo(&self, user: &str, repo: &str, path: &str) -> Result<Dataset> {
            let url = format!("https://raw.githubusercontent.com/{user}/{repo}/main/{path}");
            self.client.download_dataset(&url, None).await
        }

        #[cfg(not(feature = "download"))]
        pub fn load_from_repo_sync(&self, user: &str, repo: &str, path: &str) -> Result<Dataset> {
            let url = format!("https://raw.githubusercontent.com/{user}/{repo}/main/{path}");
            self.client.download_dataset_sync(&url, None)
        }
    }
}

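/// One-call helpers that wire together a client, a stderr progress printer,
/// and the repository types above.
///
/// A minimal usage sketch (illustrative only; the import path assumes this
/// module's location in the crate and is not compiled as a doctest):
///
/// ```ignore
/// use scirs2_datasets::external::convenience::load_from_url_sync;
///
/// // Download a CSV with default settings; progress goes to stderr.
/// let dataset = load_from_url_sync("https://example.com/data.csv", None)?;
/// println!("{} samples", dataset.n_samples());
/// ```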
pub mod convenience {
    use super::repositories::*;
    use super::*;

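    /// Downloads from `url` with an optional custom configuration, printing
    /// progress percentages to stderr.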
    #[cfg(feature = "download")]
    pub async fn load_from_url(url: &str, config: Option<ExternalConfig>) -> Result<Dataset> {
        let client = match config {
            Some(cfg) => ExternalClient::with_config(cfg)?,
            None => ExternalClient::new()?,
        };

        client
            .download_dataset(
                url,
                Some(Box::new(|downloaded, total| {
                    if total > 0 {
                        let percent = downloaded as f64 * 100.0 / total as f64;
                        eprintln!("Downloaded: {percent:.1}% ({downloaded}/{total})");
                    } else {
                        eprintln!("Downloaded: {downloaded} bytes");
                    }
                })),
            )
            .await
    }

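    /// Blocking variant of `load_from_url`; available with or without the
    /// `download` feature.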
    pub fn load_from_url_sync(url: &str, config: Option<ExternalConfig>) -> Result<Dataset> {
        let client = match config {
            Some(cfg) => ExternalClient::with_config(cfg)?,
            None => ExternalClient::new()?,
        };

        client.download_dataset_sync(
            url,
            Some(Box::new(|downloaded, total| {
                if total > 0 {
                    let percent = downloaded as f64 * 100.0 / total as f64;
                    eprintln!("Downloaded: {percent:.1}% ({downloaded}/{total})");
                } else {
                    eprintln!("Downloaded: {downloaded} bytes");
                }
            })),
        )
    }

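    /// Loads a UCI dataset by name, e.g. `"wine"` or `"adult"`.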
    #[cfg(feature = "download")]
    pub async fn load_uci_dataset(name: &str) -> Result<Dataset> {
        let repo = UCIRepository::new()?;
        repo.load_dataset(name).await
    }

    #[cfg(not(feature = "download"))]
    pub fn load_uci_dataset_sync(name: &str) -> Result<Dataset> {
        let repo = UCIRepository::new()?;
        repo.load_dataset_sync(name)
    }

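    /// Loads a file from a GitHub repository's `main` branch.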
    #[cfg(feature = "download")]
    pub async fn load_github_dataset(user: &str, repo: &str, path: &str) -> Result<Dataset> {
        let github = GitHubRepository::new()?;
        github.load_from_repo(user, repo, path).await
    }

    #[cfg(not(feature = "download"))]
    pub fn load_github_dataset_sync(user: &str, repo: &str, path: &str) -> Result<Dataset> {
        let github = GitHubRepository::new()?;
        github.load_from_repo_sync(user, repo, path)
    }

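    /// Lists the UCI dataset names understood by the loaders above.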
    pub fn list_uci_datasets() -> Result<Vec<&'static str>> {
        let repo = UCIRepository::new()?;
        Ok(repo.list_datasets())
    }
}

#[cfg(test)]
mod tests {
    use super::convenience::*;
    use super::*;

    #[test]
    fn test_external_config_default() {
        let config = ExternalConfig::default();
        assert_eq!(config.timeout_seconds, 300);
        assert_eq!(config.max_retries, 3);
        assert!(config.verify_ssl);
        assert!(config.use_cache);
    }

    #[test]
    fn test_uci_repository_list_datasets() {
        let datasets = list_uci_datasets().unwrap();
        assert!(!datasets.is_empty());
        assert!(datasets.contains(&"wine"));
        assert!(datasets.contains(&"adult"));
    }

    #[test]
    fn test_parse_arff_data() {
        let arff_content = r#"
@relation test
@attribute feature1 numeric
@attribute feature2 numeric
@attribute class {0,1}
@data
1.0,2.0,0
3.0,4.0,1
5.0,6.0,0
"#;

        let client = ExternalClient::new().unwrap();
        let dataset = client.parse_arff_data(arff_content.as_bytes()).unwrap();

        assert_eq!(dataset.n_samples(), 3);
        assert_eq!(dataset.n_features(), 2);
        assert!(dataset.target.is_some());
    }

    #[cfg(feature = "download")]
    #[tokio::test]
    async fn test_download_small_csv() {
        let url = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv";

        let result = load_from_url(url, None).await;
        match result {
            Ok(dataset) => {
                assert!(dataset.n_samples() > 0);
                assert!(dataset.n_features() > 0);
            }
            Err(e) => {
                eprintln!("Network test failed (expected in CI): {e}");
            }
        }
    }
}