sklears_utils/
cloud_storage.rs

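//! Cloud storage utilities for machine learning data processing.
//!
//! This module provides a unified interface for working with cloud storage services
//! such as AWS S3, Google Cloud Storage, and Azure Blob Storage.
//!
//! Note: no real provider backend is implemented yet. `CloudStorageFactory` returns an
//! in-memory `MockCloudStorageClient` for every provider.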
5
6use crate::{UtilsError, UtilsResult};
7use std::collections::HashMap;
8use std::fmt;
9
10/// Cloud storage configuration
11#[derive(Debug, Clone)]
12pub struct CloudStorageConfig {
13    pub provider: CloudProvider,
14    pub endpoint: Option<String>,
15    pub region: Option<String>,
16    pub access_key: Option<String>,
17    pub secret_key: Option<String>,
18    pub bucket: String,
19    pub timeout_seconds: Option<u64>,
20    pub use_ssl: bool,
21    pub custom_headers: HashMap<String, String>,
22}
23
/// Supported cloud storage providers.
///
/// `Eq` and `Hash` are derived so providers can be used as map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CloudProvider {
    /// Amazon Web Services (displayed as `aws`).
    AWS,
    /// Google Cloud (displayed as `gcp`).
    GoogleCloud,
    /// Microsoft Azure (displayed as `azure`).
    Azure,
    /// MinIO (displayed as `minio`).
    MinIO,
    /// Any other provider; displayed as the contained name.
    Custom(String),
}

impl CloudProvider {
    /// Returns the short lowercase identifier used by `Display`, without allocating.
    pub fn as_str(&self) -> &str {
        match self {
            CloudProvider::AWS => "aws",
            CloudProvider::GoogleCloud => "gcp",
            CloudProvider::Azure => "azure",
            CloudProvider::MinIO => "minio",
            CloudProvider::Custom(name) => name,
        }
    }
}

impl fmt::Display for CloudProvider {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
45
46impl Default for CloudStorageConfig {
47    fn default() -> Self {
48        Self {
49            provider: CloudProvider::AWS,
50            endpoint: None,
51            region: Some("us-east-1".to_string()),
52            access_key: None,
53            secret_key: None,
54            bucket: String::new(),
55            timeout_seconds: Some(30),
56            use_ssl: true,
57            custom_headers: HashMap::new(),
58        }
59    }
60}
61
62/// Cloud storage client trait
63pub trait CloudStorageClient {
64    /// Upload data to cloud storage
65    fn upload(&self, key: &str, data: &[u8]) -> UtilsResult<String>;
66
67    /// Download data from cloud storage
68    fn download(&self, key: &str) -> UtilsResult<Vec<u8>>;
69
70    /// Delete object from cloud storage
71    fn delete(&self, key: &str) -> UtilsResult<()>;
72
73    /// List objects with prefix
74    fn list_objects(&self, prefix: &str) -> UtilsResult<Vec<String>>;
75
76    /// Check if object exists
77    fn exists(&self, key: &str) -> UtilsResult<bool>;
78
79    /// Get object metadata
80    fn get_metadata(&self, key: &str) -> UtilsResult<ObjectMetadata>;
81
82    /// Upload file from local path
83    fn upload_file(&self, key: &str, local_path: &str) -> UtilsResult<String>;
84
85    /// Download file to local path
86    fn download_file(&self, key: &str, local_path: &str) -> UtilsResult<()>;
87}
88
/// Metadata describing a stored object.
///
/// `Default` yields an empty record (size 0, no optional fields). `PartialEq`
/// and `Eq` allow direct comparison in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ObjectMetadata {
    /// Object size in bytes.
    pub size: u64,
    /// Entity tag reported by the backend, if any.
    pub etag: Option<String>,
    /// Content (MIME) type, if known.
    pub content_type: Option<String>,
    /// Last-modification timestamp as a string (the mock client writes RFC 3339).
    pub last_modified: Option<String>,
    /// Arbitrary user-defined key/value metadata.
    pub custom_metadata: HashMap<String, String>,
}
98
99/// Mock cloud storage client for testing
100pub struct MockCloudStorageClient {
101    storage: std::sync::Arc<std::sync::Mutex<HashMap<String, Vec<u8>>>>,
102    metadata: std::sync::Arc<std::sync::Mutex<HashMap<String, ObjectMetadata>>>,
103}
104
105impl Default for MockCloudStorageClient {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111impl MockCloudStorageClient {
112    pub fn new() -> Self {
113        Self {
114            storage: std::sync::Arc::new(std::sync::Mutex::new(HashMap::new())),
115            metadata: std::sync::Arc::new(std::sync::Mutex::new(HashMap::new())),
116        }
117    }
118}
119
120impl CloudStorageClient for MockCloudStorageClient {
121    fn upload(&self, key: &str, data: &[u8]) -> UtilsResult<String> {
122        let mut storage = self.storage.lock().unwrap();
123        let mut metadata = self.metadata.lock().unwrap();
124
125        storage.insert(key.to_string(), data.to_vec());
126        metadata.insert(
127            key.to_string(),
128            ObjectMetadata {
129                size: data.len() as u64,
130                etag: Some(format!("mock-etag-{key}")),
131                content_type: Some("application/octet-stream".to_string()),
132                last_modified: Some(chrono::Utc::now().to_rfc3339()),
133                custom_metadata: HashMap::new(),
134            },
135        );
136
137        Ok(format!("mock://bucket/{key}"))
138    }
139
140    fn download(&self, key: &str) -> UtilsResult<Vec<u8>> {
141        let storage = self.storage.lock().unwrap();
142        storage
143            .get(key)
144            .cloned()
145            .ok_or_else(|| UtilsError::InvalidParameter(format!("Object not found: {key}")))
146    }
147
148    fn delete(&self, key: &str) -> UtilsResult<()> {
149        let mut storage = self.storage.lock().unwrap();
150        let mut metadata = self.metadata.lock().unwrap();
151
152        storage.remove(key);
153        metadata.remove(key);
154        Ok(())
155    }
156
157    fn list_objects(&self, prefix: &str) -> UtilsResult<Vec<String>> {
158        let storage = self.storage.lock().unwrap();
159        let objects: Vec<String> = storage
160            .keys()
161            .filter(|key| key.starts_with(prefix))
162            .cloned()
163            .collect();
164        Ok(objects)
165    }
166
167    fn exists(&self, key: &str) -> UtilsResult<bool> {
168        let storage = self.storage.lock().unwrap();
169        Ok(storage.contains_key(key))
170    }
171
172    fn get_metadata(&self, key: &str) -> UtilsResult<ObjectMetadata> {
173        let metadata = self.metadata.lock().unwrap();
174        metadata
175            .get(key)
176            .cloned()
177            .ok_or_else(|| UtilsError::InvalidParameter(format!("Object not found: {key}")))
178    }
179
180    fn upload_file(&self, key: &str, local_path: &str) -> UtilsResult<String> {
181        let data = std::fs::read(local_path)
182            .map_err(|e| UtilsError::InvalidParameter(format!("Failed to read file: {e}")))?;
183        self.upload(key, &data)
184    }
185
186    fn download_file(&self, key: &str, local_path: &str) -> UtilsResult<()> {
187        let data = self.download(key)?;
188        std::fs::write(local_path, data)
189            .map_err(|e| UtilsError::InvalidParameter(format!("Failed to write file: {e}")))?;
190        Ok(())
191    }
192}
193
194/// Cloud storage factory
195pub struct CloudStorageFactory;
196
197impl CloudStorageFactory {
198    /// Create a cloud storage client based on configuration
199    pub fn create_client(config: &CloudStorageConfig) -> UtilsResult<Box<dyn CloudStorageClient>> {
200        match config.provider {
201            CloudProvider::AWS => {
202                // In a real implementation, this would create an AWS S3 client
203                // For now, we'll use the mock client
204                Ok(Box::new(MockCloudStorageClient::new()))
205            }
206            CloudProvider::GoogleCloud => {
207                // In a real implementation, this would create a GCS client
208                Ok(Box::new(MockCloudStorageClient::new()))
209            }
210            CloudProvider::Azure => {
211                // In a real implementation, this would create an Azure Blob client
212                Ok(Box::new(MockCloudStorageClient::new()))
213            }
214            CloudProvider::MinIO => {
215                // In a real implementation, this would create a MinIO client
216                Ok(Box::new(MockCloudStorageClient::new()))
217            }
218            CloudProvider::Custom(_) => {
219                // For custom providers, use mock client
220                Ok(Box::new(MockCloudStorageClient::new()))
221            }
222        }
223    }
224}
225
226/// Cloud storage utilities for ML data processing
227pub struct CloudStorageUtils;
228
229impl CloudStorageUtils {
230    /// Upload ML dataset to cloud storage
231    pub fn upload_dataset(
232        client: &dyn CloudStorageClient,
233        dataset_path: &str,
234        key_prefix: &str,
235    ) -> UtilsResult<Vec<String>> {
236        let mut uploaded_keys = Vec::new();
237
238        // Read dataset directory
239        let entries = std::fs::read_dir(dataset_path)
240            .map_err(|e| UtilsError::InvalidParameter(format!("Failed to read directory: {e}")))?;
241
242        for entry in entries {
243            let entry = entry
244                .map_err(|e| UtilsError::InvalidParameter(format!("Failed to read entry: {e}")))?;
245            let path = entry.path();
246
247            if path.is_file() {
248                let filename = path.file_name().unwrap().to_str().unwrap();
249                let key = format!("{key_prefix}/{filename}");
250                let local_path = path.to_str().unwrap();
251
252                client.upload_file(&key, local_path)?;
253                uploaded_keys.push(key);
254            }
255        }
256
257        Ok(uploaded_keys)
258    }
259
260    /// Download ML dataset from cloud storage
261    pub fn download_dataset(
262        client: &dyn CloudStorageClient,
263        key_prefix: &str,
264        local_path: &str,
265    ) -> UtilsResult<Vec<String>> {
266        let objects = client.list_objects(key_prefix)?;
267        let mut downloaded_files = Vec::new();
268
269        // Create local directory if it doesn't exist
270        std::fs::create_dir_all(local_path).map_err(|e| {
271            UtilsError::InvalidParameter(format!("Failed to create directory: {e}"))
272        })?;
273
274        for object_key in objects {
275            let filename = object_key.split('/').next_back().unwrap_or(&object_key);
276            let local_file_path = format!("{local_path}/{filename}");
277
278            client.download_file(&object_key, &local_file_path)?;
279            downloaded_files.push(local_file_path);
280        }
281
282        Ok(downloaded_files)
283    }
284
285    /// Sync local dataset with cloud storage
286    pub fn sync_dataset(
287        client: &dyn CloudStorageClient,
288        local_path: &str,
289        key_prefix: &str,
290        sync_mode: SyncMode,
291    ) -> UtilsResult<SyncResult> {
292        let mut result = SyncResult::default();
293
294        match sync_mode {
295            SyncMode::Upload => {
296                let uploaded = Self::upload_dataset(client, local_path, key_prefix)?;
297                result.uploaded = uploaded;
298            }
299            SyncMode::Download => {
300                let downloaded = Self::download_dataset(client, key_prefix, local_path)?;
301                result.downloaded = downloaded;
302            }
303            SyncMode::Bidirectional => {
304                // Simple bidirectional sync: upload first, then download
305                let uploaded = Self::upload_dataset(client, local_path, key_prefix)?;
306                let downloaded = Self::download_dataset(client, key_prefix, local_path)?;
307                result.uploaded = uploaded;
308                result.downloaded = downloaded;
309            }
310        }
311
312        Ok(result)
313    }
314
315    /// Batch upload multiple files with metadata
316    pub fn batch_upload(
317        client: &dyn CloudStorageClient,
318        files: &[(String, String)], // (local_path, key)
319    ) -> UtilsResult<Vec<String>> {
320        let mut uploaded_keys = Vec::new();
321
322        for (local_path, key) in files {
323            let result = client.upload_file(key, local_path)?;
324            uploaded_keys.push(result);
325        }
326
327        Ok(uploaded_keys)
328    }
329
330    /// Calculate storage metrics for ML datasets
331    pub fn calculate_storage_metrics(
332        client: &dyn CloudStorageClient,
333        key_prefix: &str,
334    ) -> UtilsResult<StorageMetrics> {
335        let objects = client.list_objects(key_prefix)?;
336        let mut total_size = 0;
337        let mut total_objects = 0;
338        let mut file_types = HashMap::new();
339
340        for object_key in objects {
341            if let Ok(metadata) = client.get_metadata(&object_key) {
342                total_size += metadata.size;
343                total_objects += 1;
344
345                // Extract file extension
346                if let Some(ext) = object_key.split('.').next_back() {
347                    *file_types.entry(ext.to_string()).or_insert(0) += 1;
348                }
349            }
350        }
351
352        Ok(StorageMetrics {
353            total_size_bytes: total_size,
354            total_objects,
355            file_types,
356            average_file_size: if total_objects > 0 {
357                total_size / total_objects
358            } else {
359                0
360            },
361        })
362    }
363}
364
/// Direction of a [`CloudStorageUtils::sync_dataset`] run.
///
/// A field-less enum, so it is `Copy` and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SyncMode {
    /// Push local files to storage.
    Upload,
    /// Pull stored objects into the local directory.
    Download,
    /// Upload first, then download.
    Bidirectional,
}
372
/// Outcome of a [`CloudStorageUtils::sync_dataset`] run.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SyncResult {
    /// Object keys that were uploaded.
    pub uploaded: Vec<String>,
    /// Local file paths written by downloads.
    pub downloaded: Vec<String>,
    /// Per-item error messages. `sync_dataset` currently leaves this empty,
    /// because it aborts on the first error instead.
    pub errors: Vec<String>,
}
380
/// Aggregate size statistics for a set of stored objects.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StorageMetrics {
    /// Sum of object sizes in bytes.
    pub total_size_bytes: u64,
    /// Number of objects counted.
    pub total_objects: u64,
    /// Object count per file extension (stored without the leading dot).
    pub file_types: HashMap<String, usize>,
    /// Mean object size in bytes.
    pub average_file_size: u64,
}

impl fmt::Display for StorageMetrics {
    /// Renders a multi-line, human-readable report.
    ///
    /// Sizes are divided by 1024 (KB) or 1024 * 1024 (MB). File types are
    /// listed alphabetically so the output is deterministic.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const KIB: f64 = 1024.0;

        writeln!(f, "Storage Metrics:")?;
        writeln!(
            f,
            "  Total Size: {:.2} MB",
            self.total_size_bytes as f64 / KIB / KIB
        )?;
        writeln!(f, "  Total Objects: {}", self.total_objects)?;
        writeln!(
            f,
            "  Average File Size: {:.2} KB",
            self.average_file_size as f64 / KIB
        )?;
        writeln!(f, "  File Types:")?;
        // HashMap iteration order is arbitrary; sort by extension for stable output.
        let mut types: Vec<(&String, &usize)> = self.file_types.iter().collect();
        types.sort_unstable();
        for (ext, count) in types {
            writeln!(f, "    .{ext}: {count}")?;
        }
        Ok(())
    }
}
411
#[cfg(test)]
mod tests {
    // These tests use only the in-memory `MockCloudStorageClient`, so they need
    // no network access or credentials.
    use super::*;
    use std::fs;

    #[test]
    fn test_cloud_storage_config() {
        let config = CloudStorageConfig {
            provider: CloudProvider::AWS,
            bucket: "test-bucket".to_string(),
            ..Default::default()
        };

        assert_eq!(config.provider, CloudProvider::AWS);
        assert_eq!(config.bucket, "test-bucket");
        assert_eq!(config.region, Some("us-east-1".to_string()));
        assert!(config.use_ssl);
    }

    #[test]
    fn test_cloud_provider_display() {
        assert_eq!(CloudProvider::AWS.to_string(), "aws");
        assert_eq!(CloudProvider::GoogleCloud.to_string(), "gcp");
        assert_eq!(CloudProvider::Azure.to_string(), "azure");
        assert_eq!(CloudProvider::MinIO.to_string(), "minio");
        assert_eq!(
            CloudProvider::Custom("test".to_string()).to_string(),
            "test"
        );
    }

    #[test]
    fn test_mock_client_upload_download() {
        let client = MockCloudStorageClient::new();
        let test_data = b"hello world";

        // Test upload
        let url = client.upload("test-key", test_data).unwrap();
        assert_eq!(url, "mock://bucket/test-key");

        // Test download
        let downloaded = client.download("test-key").unwrap();
        assert_eq!(downloaded, test_data);

        // Test exists
        assert!(client.exists("test-key").unwrap());
        assert!(!client.exists("nonexistent-key").unwrap());
    }

    #[test]
    fn test_mock_client_metadata() {
        let client = MockCloudStorageClient::new();
        let test_data = b"hello world";

        client.upload("test-key", test_data).unwrap();

        let metadata = client.get_metadata("test-key").unwrap();
        assert_eq!(metadata.size, test_data.len() as u64);
        assert_eq!(metadata.etag, Some("mock-etag-test-key".to_string()));
        assert_eq!(
            metadata.content_type,
            Some("application/octet-stream".to_string())
        );
    }

    #[test]
    fn test_mock_client_list_objects() {
        let client = MockCloudStorageClient::new();

        client.upload("data/file1.txt", b"content1").unwrap();
        client.upload("data/file2.txt", b"content2").unwrap();
        client.upload("other/file3.txt", b"content3").unwrap();

        let objects = client.list_objects("data/").unwrap();
        assert_eq!(objects.len(), 2);
        assert!(objects.contains(&"data/file1.txt".to_string()));
        assert!(objects.contains(&"data/file2.txt".to_string()));
    }

    #[test]
    fn test_mock_client_delete() {
        let client = MockCloudStorageClient::new();

        client.upload("test-key", b"hello").unwrap();
        assert!(client.exists("test-key").unwrap());

        client.delete("test-key").unwrap();
        assert!(!client.exists("test-key").unwrap());
    }

    #[test]
    fn test_mock_client_missing_key_errors() {
        let client = MockCloudStorageClient::new();
        assert!(client.download("missing").is_err());
        assert!(client.get_metadata("missing").is_err());
        // Deleting an absent key is not an error.
        client.delete("missing").unwrap();
    }

    #[test]
    fn test_mock_client_overwrite_updates_metadata() {
        let client = MockCloudStorageClient::new();
        client.upload("key", b"one").unwrap();
        client.upload("key", b"three").unwrap();
        assert_eq!(client.download("key").unwrap(), b"three");
        assert_eq!(client.get_metadata("key").unwrap().size, 5);
    }

    #[test]
    fn test_list_objects_empty_prefix_lists_everything() {
        let client = MockCloudStorageClient::new();
        client.upload("a/1", b"x").unwrap();
        client.upload("b/2", b"y").unwrap();
        let mut objects = client.list_objects("").unwrap();
        objects.sort();
        assert_eq!(objects, vec!["a/1".to_string(), "b/2".to_string()]);
    }

    #[test]
    fn test_cloud_storage_factory() {
        let config = CloudStorageConfig {
            provider: CloudProvider::AWS,
            bucket: "test-bucket".to_string(),
            ..Default::default()
        };

        let client = CloudStorageFactory::create_client(&config).unwrap();

        // Test that we can use the client
        client.upload("test", b"data").unwrap();
        let downloaded = client.download("test").unwrap();
        assert_eq!(downloaded, b"data");
    }

    #[test]
    fn test_storage_metrics_display() {
        let mut file_types = HashMap::new();
        file_types.insert("txt".to_string(), 5);
        file_types.insert("csv".to_string(), 3);

        let metrics = StorageMetrics {
            total_size_bytes: 1_048_576, // 1 MB
            total_objects: 8,
            file_types,
            average_file_size: 131_072, // 128 KB
        };

        let display = metrics.to_string();
        assert!(display.contains("Total Size: 1.00 MB"));
        assert!(display.contains("Total Objects: 8"));
        assert!(display.contains("Average File Size: 128.00 KB"));
        assert!(display.contains(".txt: 5"));
        assert!(display.contains(".csv: 3"));
    }

    #[test]
    fn test_sync_result_default() {
        let result = SyncResult::default();
        assert!(result.uploaded.is_empty());
        assert!(result.downloaded.is_empty());
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_file_upload_download() {
        let client = MockCloudStorageClient::new();
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        // Create test file
        fs::write(&file_path, b"test content").unwrap();

        // Upload file
        let url = client
            .upload_file("test.txt", file_path.to_str().unwrap())
            .unwrap();
        assert_eq!(url, "mock://bucket/test.txt");

        // Download file
        let download_path = temp_dir.path().join("downloaded.txt");
        client
            .download_file("test.txt", download_path.to_str().unwrap())
            .unwrap();

        // Verify content
        let downloaded_content = fs::read(&download_path).unwrap();
        assert_eq!(downloaded_content, b"test content");
    }

    #[test]
    fn test_dataset_round_trip() {
        let client = MockCloudStorageClient::new();
        let src = tempfile::tempdir().unwrap();
        let dst = tempfile::tempdir().unwrap();
        fs::write(src.path().join("a.txt"), b"alpha").unwrap();
        fs::write(src.path().join("b.csv"), b"beta").unwrap();

        let mut uploaded =
            CloudStorageUtils::upload_dataset(&client, src.path().to_str().unwrap(), "ds")
                .unwrap();
        uploaded.sort();
        assert_eq!(uploaded, vec!["ds/a.txt".to_string(), "ds/b.csv".to_string()]);

        let downloaded =
            CloudStorageUtils::download_dataset(&client, "ds", dst.path().to_str().unwrap())
                .unwrap();
        assert_eq!(downloaded.len(), 2);
        assert_eq!(fs::read(dst.path().join("a.txt")).unwrap(), b"alpha");
        assert_eq!(fs::read(dst.path().join("b.csv")).unwrap(), b"beta");
    }

    #[test]
    fn test_calculate_storage_metrics() {
        let client = MockCloudStorageClient::new();

        // Upload test files
        client.upload("data/file1.txt", b"hello").unwrap();
        client.upload("data/file2.csv", b"world").unwrap();
        client.upload("data/file3.txt", b"test").unwrap();

        let metrics = CloudStorageUtils::calculate_storage_metrics(&client, "data/").unwrap();

        assert_eq!(metrics.total_objects, 3);
        assert_eq!(metrics.total_size_bytes, 14); // 5 + 5 + 4
        assert_eq!(metrics.file_types.get("txt"), Some(&2));
        assert_eq!(metrics.file_types.get("csv"), Some(&1));
    }

    #[test]
    fn test_batch_upload() {
        let client = MockCloudStorageClient::new();
        let temp_dir = tempfile::tempdir().unwrap();

        // Create test files
        let file1_path = temp_dir.path().join("file1.txt");
        let file2_path = temp_dir.path().join("file2.txt");
        fs::write(&file1_path, b"content1").unwrap();
        fs::write(&file2_path, b"content2").unwrap();

        let files = vec![
            (
                file1_path.to_str().unwrap().to_string(),
                "batch/file1.txt".to_string(),
            ),
            (
                file2_path.to_str().unwrap().to_string(),
                "batch/file2.txt".to_string(),
            ),
        ];

        let results = CloudStorageUtils::batch_upload(&client, &files).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], "mock://bucket/batch/file1.txt");
        assert_eq!(results[1], "mock://bucket/batch/file2.txt");

        // Verify uploads
        let content1 = client.download("batch/file1.txt").unwrap();
        let content2 = client.download("batch/file2.txt").unwrap();
        assert_eq!(content1, b"content1");
        assert_eq!(content2, b"content2");
    }
}