scirs2_datasets/
external.rs

1//! External data sources integration
2//!
3//! This module provides functionality for loading datasets from external sources including:
4//! - URLs and web resources
5//! - API endpoints
6//! - Popular dataset repositories
7//! - Remote file systems
8
9use std::collections::HashMap;
10use std::io::Read;
11use std::path::Path;
12use std::time::Duration;
13
14use scirs2_core::ndarray::{Array1, Array2};
15use serde::{Deserialize, Serialize};
16
17use crate::cache::DatasetCache;
18use crate::error::{DatasetsError, Result};
19use crate::loaders::{load_csv, CsvConfig};
20use crate::utils::Dataset;
21
/// Configuration for external data source access
///
/// Construct via [`Default`] and adjust individual fields as needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalConfig {
    /// Timeout for requests (in seconds)
    pub timeout_seconds: u64,
    /// Number of retry attempts after a failed request
    pub max_retries: u32,
    /// User agent string for requests
    pub user_agent: String,
    /// Extra headers added to every request (e.g. an Authorization header)
    pub headers: HashMap<String, String>,
    /// Whether to verify SSL certificates (disable only for trusted hosts)
    pub verify_ssl: bool,
    /// Cache downloaded files on disk and reuse them on later requests
    pub use_cache: bool,
}
38
39impl Default for ExternalConfig {
40    fn default() -> Self {
41        Self {
42            timeout_seconds: 300, // 5 minutes
43            max_retries: 3,
44            user_agent: "scirs2-datasets/0.1.0".to_string(),
45            headers: HashMap::new(),
46            verify_ssl: true,
47            use_cache: true,
48        }
49    }
50}
51
/// Progress callback for download operations
///
/// Invoked as `(bytes_downloaded_so_far, total_bytes)`; `total_bytes` is 0
/// when the server did not report a content length.
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;
54
/// External data source client
pub struct ExternalClient {
    // Request/retry/caching settings for this client.
    config: ExternalConfig,
    // On-disk cache consulted and populated when `config.use_cache` is true.
    cache: DatasetCache,
    // Async HTTP client; only present when the `download` feature is enabled.
    #[cfg(feature = "download")]
    client: reqwest::Client,
}
62
63impl ExternalClient {
64    /// Create a new external client with default configuration
65    pub fn new() -> Result<Self> {
66        Self::with_config(ExternalConfig::default())
67    }
68
69    /// Create a new external client with custom configuration
70    pub fn with_config(config: ExternalConfig) -> Result<Self> {
71        let cache = DatasetCache::new(crate::cache::get_cachedir()?);
72
73        #[cfg(feature = "download")]
74        let client = {
75            let mut builder = reqwest::Client::builder()
76                .timeout(Duration::from_secs(config.timeout_seconds))
77                .user_agent(&config.user_agent);
78
79            if !config.verify_ssl {
80                builder = builder.danger_accept_invalid_certs(true);
81            }
82
83            builder
84                .build()
85                .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?
86        };
87
88        Ok(Self {
89            config,
90            cache,
91            #[cfg(feature = "download")]
92            client,
93        })
94    }
95
96    /// Download a dataset from a URL
97    #[cfg(feature = "download")]
98    pub async fn download_dataset(
99        &self,
100        url: &str,
101        progress: Option<ProgressCallback>,
102    ) -> Result<Dataset> {
103        // Check cache first
104        if self.config.use_cache {
105            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
106            if let Ok(cached_data) = self.cache.read_cached(&cache_key) {
107                return self.parse_cached_data(&cached_data);
108            }
109        }
110
111        // Download the file
112        let response = self.make_request(url).await?;
113        let total_size = response.content_length().unwrap_or(0);
114
115        let mut downloaded = 0u64;
116        let mut buffer = Vec::new();
117        let mut stream = response.bytes_stream();
118
119        use futures_util::StreamExt;
120        while let Some(chunk) = stream.next().await {
121            let chunk = chunk.map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
122            downloaded += chunk.len() as u64;
123            buffer.extend_from_slice(&chunk);
124
125            if let Some(ref callback) = progress {
126                callback(downloaded, total_size);
127            }
128        }
129
130        // Cache the downloaded data
131        if self.config.use_cache {
132            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
133            let _ = self.cache.put(&cache_key, &buffer);
134        }
135
136        // Parse the data based on content type or URL extension
137        self.parse_downloaded_data(url, &buffer)
138    }
139
140    /// Download a dataset synchronously (blocking) - when download feature is enabled
141    #[cfg(feature = "download")]
142    pub fn download_dataset_sync(
143        &self,
144        url: &str,
145        progress: Option<ProgressCallback>,
146    ) -> Result<Dataset> {
147        // Use tokio runtime to block on the async version
148        let rt = tokio::runtime::Runtime::new()
149            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
150        rt.block_on(self.download_dataset(url, progress))
151    }
152
153    /// Download a dataset synchronously (blocking) - fallback when download feature is disabled
154    #[cfg(not(feature = "download"))]
155    pub fn download_dataset_sync(
156        &self,
157        url: &str,
158        progress: Option<ProgressCallback>,
159    ) -> Result<Dataset> {
160        // Fallback implementation using ureq
161        self.download_with_ureq(url, progress)
162    }
163
164    /// Download using ureq (synchronous HTTP client)
165    #[allow(dead_code)]
166    fn download_with_ureq(&self, url: &str, progress: Option<ProgressCallback>) -> Result<Dataset> {
167        // Check cache first
168        if self.config.use_cache {
169            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
170            if let Ok(cached_data) = self.cache.read_cached(&cache_key) {
171                return self.parse_cached_data(&cached_data);
172            }
173        }
174
175        let mut request = ureq::get(url).header("User-Agent", &self.config.user_agent);
176
177        // Add custom headers
178        for (key, value) in &self.config.headers {
179            request = request.header(key, value);
180        }
181
182        let response = request
183            .call()
184            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
185
186        // Get content-length header if present (case-insensitive per HTTP spec)
187        let headers = response.headers();
188        let total_size = headers
189            .get("Content-Length")
190            .and_then(|hv| hv.to_str().ok())
191            .and_then(|s| s.parse::<u64>().ok())
192            .unwrap_or(0);
193
194        // Read body via body reader (ureq 3.x)
195        let mut body = response.into_body();
196        let buffer = body
197            .read_to_vec()
198            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
199        let downloaded = buffer.len() as u64;
200        if let Some(ref callback) = progress {
201            callback(downloaded, total_size);
202        }
203
204        // Cache the downloaded data
205        if self.config.use_cache {
206            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
207            let _ = self.cache.put(&cache_key, &buffer);
208        }
209
210        // Parse the data
211        self.parse_downloaded_data(url, &buffer)
212    }
213
214    #[cfg(feature = "download")]
215    async fn make_request(&self, url: &str) -> Result<reqwest::Response> {
216        let mut request = self.client.get(url);
217
218        // Add custom headers
219        for (key, value) in &self.config.headers {
220            request = request.header(key, value);
221        }
222
223        let mut last_error = None;
224
225        for attempt in 0..=self.config.max_retries {
226            match request
227                .try_clone()
228                .ok_or_else(|| {
229                    DatasetsError::IoError(std::io::Error::other("Failed to clone request"))
230                })?
231                .send()
232                .await
233            {
234                Ok(response) => {
235                    if response.status().is_success() {
236                        return Ok(response);
237                    } else {
238                        last_error = Some(DatasetsError::IoError(std::io::Error::other(format!(
239                            "HTTP {}: {}",
240                            response.status(),
241                            response.status().canonical_reason().unwrap_or("Unknown")
242                        ))));
243                    }
244                }
245                Err(e) => {
246                    last_error = Some(DatasetsError::IoError(std::io::Error::other(e)));
247                }
248            }
249
250            if attempt < self.config.max_retries {
251                tokio::time::sleep(Duration::from_millis(1000 * (attempt + 1) as u64)).await;
252            }
253        }
254
255        Err(last_error.unwrap())
256    }
257
258    fn parse_cached_data(&self, data: &[u8]) -> Result<Dataset> {
259        // Try to deserialize as JSON first (cached parsed data)
260        if let Ok(dataset) = serde_json::from_slice::<Dataset>(data) {
261            return Ok(dataset);
262        }
263
264        // Otherwise parse as raw data
265        self.parse_raw_data(data, None)
266    }
267
268    fn parse_downloaded_data(&self, url: &str, data: &[u8]) -> Result<Dataset> {
269        let extension = Path::new(url)
270            .extension()
271            .and_then(|s| s.to_str())
272            .unwrap_or("")
273            .to_lowercase();
274
275        self.parse_raw_data(data, Some(&extension))
276    }
277
278    fn parse_raw_data(&self, data: &[u8], extension: Option<&str>) -> Result<Dataset> {
279        match extension {
280            Some("csv") | None => {
281                // Try CSV parsing
282                let csv_data = String::from_utf8(data.to_vec())
283                    .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;
284
285                // Write to temporary file for CSV parsing
286                let temp_file = tempfile::NamedTempFile::new().map_err(DatasetsError::IoError)?;
287
288                std::fs::write(temp_file.path(), &csv_data).map_err(DatasetsError::IoError)?;
289
290                load_csv(temp_file.path(), CsvConfig::default())
291            }
292            Some("json") => {
293                // Try JSON parsing
294                let json_str = String::from_utf8(data.to_vec())
295                    .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;
296
297                serde_json::from_str(&json_str)
298                    .map_err(|e| DatasetsError::FormatError(format!("Invalid JSON: {e}")))
299            }
300            Some("arff") => {
301                // Basic ARFF parsing (simplified)
302                self.parse_arff_data(data)
303            }
304            _ => {
305                // Try to auto-detect format
306                self.auto_detect_and_parse(data)
307            }
308        }
309    }
310
311    fn parse_arff_data(&self, data: &[u8]) -> Result<Dataset> {
312        let content = String::from_utf8(data.to_vec())
313            .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;
314
315        let lines = content.lines();
316        let mut attributes = Vec::new();
317        let mut data_section = false;
318        let mut data_lines = Vec::new();
319
320        for line in lines {
321            let line = line.trim();
322
323            if line.is_empty() || line.starts_with('%') {
324                continue;
325            }
326
327            if line.to_lowercase().starts_with("@attribute") {
328                let parts: Vec<&str> = line.split_whitespace().collect();
329                if parts.len() >= 2 {
330                    attributes.push(parts[1].to_string());
331                }
332            } else if line.to_lowercase().starts_with("@data") {
333                data_section = true;
334            } else if data_section {
335                data_lines.push(line.to_string());
336            }
337        }
338
339        // Parse data rows
340        let mut rows: Vec<Vec<f64>> = Vec::new();
341        for line in data_lines {
342            let values: Result<Vec<f64>> = line
343                .split(',')
344                .map(|s| {
345                    s.trim()
346                        .parse::<f64>()
347                        .map_err(|_| DatasetsError::FormatError(format!("Invalid number: {s}")))
348                })
349                .collect();
350
351            match values {
352                Ok(row) => rows.push(row),
353                Err(_) => continue, // Skip invalid rows
354            }
355        }
356
357        if rows.is_empty() {
358            return Err(DatasetsError::FormatError(
359                "No valid data rows found".to_string(),
360            ));
361        }
362
363        let n_features = rows[0].len();
364        let n_samples = rows.len();
365
366        // Assume last column is target if more than one column
367        let (data_cols, target_col) = if n_features > 1 {
368            (n_features - 1, Some(n_features - 1))
369        } else {
370            (n_features, None)
371        };
372
373        // Create data array
374        let mut data_vec = Vec::with_capacity(n_samples * data_cols);
375        let mut target_vec = if target_col.is_some() {
376            Some(Vec::with_capacity(n_samples))
377        } else {
378            None
379        };
380
381        for row in rows {
382            for (i, &value) in row.iter().enumerate() {
383                if i < data_cols {
384                    data_vec.push(value);
385                } else if let Some(ref mut targets) = target_vec {
386                    targets.push(value);
387                }
388            }
389        }
390
391        let data = Array2::from_shape_vec((n_samples, data_cols), data_vec)
392            .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
393
394        let target = target_vec.map(Array1::from_vec);
395
396        Ok(Dataset {
397            data,
398            target,
399            featurenames: Some(attributes[..data_cols].to_vec()),
400            targetnames: None,
401            feature_descriptions: None,
402            description: Some("ARFF dataset loaded from external source".to_string()),
403            metadata: std::collections::HashMap::new(),
404        })
405    }
406
407    fn auto_detect_and_parse(&self, data: &[u8]) -> Result<Dataset> {
408        let content = String::from_utf8(data.to_vec())
409            .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;
410
411        // Try JSON first
412        if content.trim().starts_with('{') || content.trim().starts_with('[') {
413            if let Ok(dataset) = serde_json::from_str::<Dataset>(&content) {
414                return Ok(dataset);
415            }
416        }
417
418        // Try CSV
419        if content.contains(',') || content.contains('\t') {
420            return self.parse_raw_data(data, Some("csv"));
421        }
422
423        // Try ARFF
424        if content.to_lowercase().contains("@relation") {
425            return self.parse_arff_data(data);
426        }
427
428        Err(DatasetsError::FormatError(
429            "Unable to auto-detect data format".to_string(),
430        ))
431    }
432}
433
434/// Popular dataset repository APIs
435pub mod repositories {
436    use super::*;
437
438    /// UCI Machine Learning Repository client
439    pub struct UCIRepository {
440        client: ExternalClient,
441        base_url: String,
442    }
443
444    impl UCIRepository {
445        /// Create a new UCI repository client
446        pub fn new() -> Result<Self> {
447            Ok(Self {
448                client: ExternalClient::new()?,
449                base_url: "https://archive.ics.uci.edu/ml/machine-learning-databases".to_string(),
450            })
451        }
452
453        /// Loads a dataset from the UCI Machine Learning Repository.
454        ///
455        /// # Arguments
456        /// * `name` - The name of the dataset to load
457        ///
458        /// # Returns
459        /// A `Dataset` containing the loaded data
460        #[cfg(feature = "download")]
461        pub async fn load_dataset(&self, name: &str) -> Result<Dataset> {
462            let url = match name {
463                "adult" => format!("{}/adult/adult.data", self.base_url),
464                "wine" => format!("{}/wine/wine.data", self.base_url),
465                "glass" => format!("{}/glass/glass.data", self.base_url),
466                "hepatitis" => format!("{}/hepatitis/hepatitis.data", self.base_url),
467                "heart-disease" => {
468                    format!("{}/heart-disease/processed.cleveland.data", self.base_url)
469                }
470                _ => {
471                    return Err(DatasetsError::NotFound(format!(
472                        "UCI dataset '{name}' not found"
473                    )))
474                }
475            };
476
477            self.client.download_dataset(&url, None).await
478        }
479
480        #[cfg(not(feature = "download"))]
481        /// Load a UCI dataset synchronously
482        pub fn load_dataset_sync(&self, name: &str) -> Result<Dataset> {
483            let url = match name {
484                "adult" => format!("{}/adult/adult.data", self.base_url),
485                "wine" => format!("{}/wine/wine.data", self.base_url),
486                "glass" => format!("{}/glass/glass.data", self.base_url),
487                "hepatitis" => format!("{}/hepatitis/hepatitis.data", self.base_url),
488                "heart-disease" => {
489                    format!("{}/heart-disease/processed.cleveland.data", self.base_url)
490                }
491                _ => {
492                    return Err(DatasetsError::NotFound(format!(
493                        "UCI dataset '{name}' not found"
494                    )))
495                }
496            };
497
498            self.client.download_dataset_sync(&url, None)
499        }
500
501        /// List available UCI datasets
502        pub fn list_datasets(&self) -> Vec<&'static str> {
503            vec!["adult", "wine", "glass", "hepatitis", "heart-disease"]
504        }
505    }
506
507    /// Kaggle dataset client (requires API key)
508    pub struct KaggleRepository {
509        #[allow(dead_code)]
510        client: ExternalClient,
511        #[allow(dead_code)]
512        api_key: Option<String>,
513    }
514
515    impl KaggleRepository {
516        /// Create a new Kaggle repository client
517        pub fn new(_apikey: Option<String>) -> Result<Self> {
518            let mut config = ExternalConfig::default();
519
520            if let Some(ref key) = _apikey {
521                config
522                    .headers
523                    .insert("Authorization".to_string(), format!("Bearer {key}"));
524            }
525
526            Ok(Self {
527                client: ExternalClient::with_config(config)?,
528                api_key: _apikey,
529            })
530        }
531
532        /// Loads competition data from Kaggle.
533        ///
534        /// # Arguments
535        /// * `competition` - The name of the Kaggle competition
536        ///
537        /// # Returns
538        /// A `Dataset` containing the competition data
539        #[cfg(feature = "download")]
540        pub async fn load_competition_data(&self, competition: &str) -> Result<Dataset> {
541            if self.api_key.is_none() {
542                return Err(DatasetsError::AuthenticationError(
543                    "Kaggle API key required".to_string(),
544                ));
545            }
546
547            let url = format!(
548                "https://www.kaggle.com/api/v1/competitions/{}/data/download",
549                competition
550            );
551            self.client.download_dataset(&url, None).await
552        }
553    }
554
555    /// GitHub repository client for datasets
556    pub struct GitHubRepository {
557        client: ExternalClient,
558    }
559
560    impl GitHubRepository {
561        /// Create a new GitHub repository client
562        pub fn new() -> Result<Self> {
563            Ok(Self {
564                client: ExternalClient::new()?,
565            })
566        }
567
568        /// Loads a dataset from a GitHub repository.
569        ///
570        /// # Arguments
571        /// * `user` - The GitHub username
572        /// * `repo` - The repository name
573        /// * `path` - The path to the dataset file within the repository
574        ///
575        /// # Returns
576        /// A `Dataset` containing the loaded data
577        #[cfg(feature = "download")]
578        pub async fn load_from_repo(&self, user: &str, repo: &str, path: &str) -> Result<Dataset> {
579            let url = format!("https://raw.githubusercontent.com/{user}/{repo}/main/{path}");
580            self.client.download_dataset(&url, None).await
581        }
582
583        #[cfg(not(feature = "download"))]
584        /// Load a dataset from GitHub repository synchronously
585        pub fn load_from_repo_sync(&self, user: &str, repo: &str, path: &str) -> Result<Dataset> {
586            let url = format!("https://raw.githubusercontent.com/{user}/{repo}/main/{path}");
587            self.client.download_dataset_sync(&url, None)
588        }
589    }
590}
591
592/// Convenience functions for common external data operations
593pub mod convenience {
594    use super::repositories::*;
595    use super::*;
596
597    /// Load a dataset from a URL with progress tracking
598    #[cfg(feature = "download")]
599    pub async fn load_from_url(url: &str, config: Option<ExternalConfig>) -> Result<Dataset> {
600        let client = match config {
601            Some(cfg) => ExternalClient::with_config(cfg)?,
602            None => ExternalClient::new()?,
603        };
604
605        client
606            .download_dataset(
607                url,
608                Some(Box::new(|downloaded, total| {
609                    if total > 0 {
610                        let percent = (downloaded * 100) / total;
611                        eprintln!("Downloaded: {percent:.1}% ({downloaded}/{total})");
612                    } else {
613                        eprintln!("Downloaded: {downloaded} bytes");
614                    }
615                })),
616            )
617            .await
618    }
619
620    /// Load a dataset from a URL synchronously
621    pub fn load_from_url_sync(url: &str, config: Option<ExternalConfig>) -> Result<Dataset> {
622        let client = match config {
623            Some(cfg) => ExternalClient::with_config(cfg)?,
624            None => ExternalClient::new()?,
625        };
626
627        client.download_dataset_sync(
628            url,
629            Some(Box::new(|downloaded, total| {
630                if total > 0 {
631                    let percent = (downloaded * 100) / total;
632                    eprintln!("Downloaded: {percent:.1}% ({downloaded}/{total})");
633                } else {
634                    eprintln!("Downloaded: {downloaded} bytes");
635                }
636            })),
637        )
638    }
639
640    /// Load a UCI dataset by name
641    #[cfg(feature = "download")]
642    pub async fn load_uci_dataset(name: &str) -> Result<Dataset> {
643        let repo = UCIRepository::new()?;
644        repo.load_dataset(name).await
645    }
646
647    /// Load a UCI dataset by name synchronously
648    #[cfg(not(feature = "download"))]
649    pub fn load_uci_dataset_sync(name: &str) -> Result<Dataset> {
650        let repo = UCIRepository::new()?;
651        repo.load_dataset_sync(name)
652    }
653
654    /// Load a dataset from GitHub repository
655    #[cfg(feature = "download")]
656    pub async fn load_github_dataset(user: &str, repo: &str, path: &str) -> Result<Dataset> {
657        let github = GitHubRepository::new()?;
658        github.load_from_repo(user, repo, path).await
659    }
660
661    /// Load a dataset from GitHub repository synchronously
662    #[cfg(not(feature = "download"))]
663    pub fn load_github_dataset_sync(user: &str, repo: &str, path: &str) -> Result<Dataset> {
664        let github = GitHubRepository::new()?;
665        github.load_from_repo_sync(user, repo, path)
666    }
667
668    /// List available UCI datasets
669    pub fn list_uci_datasets() -> Result<Vec<&'static str>> {
670        let repo = UCIRepository::new()?;
671        Ok(repo.list_datasets())
672    }
673}
674
#[cfg(test)]
mod tests {
    use super::convenience::*;
    use super::*;

    // Defaults must match the values documented on `ExternalConfig`.
    #[test]
    fn test_external_config_default() {
        let cfg = ExternalConfig::default();
        assert_eq!((cfg.timeout_seconds, cfg.max_retries), (300, 3));
        assert!(cfg.verify_ssl);
        assert!(cfg.use_cache);
    }

    // The registry should expose at least the well-known dataset names.
    #[test]
    fn test_uci_repository_list_datasets() {
        let names = list_uci_datasets().unwrap();
        assert!(!names.is_empty());
        for expected in ["wine", "adult"] {
            assert!(names.contains(&expected));
        }
    }

    // A three-column ARFF file: two features plus a target column.
    #[test]
    fn test_parse_arff_data() {
        let arff_content = r#"
@relation test
@attribute feature1 numeric
@attribute feature2 numeric
@attribute class {0,1}
@data
1.0,2.0,0
3.0,4.0,1
5.0,6.0,0
"#;

        let parsed = ExternalClient::new()
            .unwrap()
            .parse_arff_data(arff_content.as_bytes())
            .unwrap();

        assert_eq!((parsed.n_samples(), parsed.n_features()), (3, 2));
        assert!(parsed.target.is_some());
    }

    #[tokio::test]
    #[cfg(feature = "download")]
    async fn test_download_small_csv() {
        // Test with a small public CSV dataset
        let url = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv";

        match load_from_url(url, None).await {
            Ok(ds) => {
                assert!(ds.n_samples() > 0);
                assert!(ds.n_features() > 0);
            }
            // Network tests may fail in CI, so we just log the error.
            Err(e) => eprintln!("Network test failed (expected in CI): {}", e),
        }
    }
}