1use std::collections::HashMap;
10
11use serde::{Deserialize, Serialize};
12
13use crate::cache::DatasetCache;
14use crate::error::{DatasetsError, Result};
15use crate::external::ExternalClient;
16use crate::utils::Dataset;
17
/// Configuration describing one cloud object-storage target (bucket/container
/// plus how to reach and authenticate against it).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CloudConfig {
    /// Which storage backend to talk to.
    pub provider: CloudProvider,
    /// Region, e.g. "us-east-1". Only used when building S3 URLs; when unset,
    /// URL building falls back to "us-east-1".
    pub region: Option<String>,
    /// Bucket (S3/GCS) or container (Azure) name.
    pub bucket: String,
    /// Credentials used to authenticate requests.
    pub credentials: CloudCredentials,
    /// Custom endpoint URL; required for `CloudProvider::S3Compatible`.
    pub endpoint: Option<String>,
    /// Use path-style addressing (`host/bucket/key`) instead of
    /// virtual-hosted style (`bucket.host/key`).
    pub path_style: bool,
    /// Extra headers merged into every request; applied after the
    /// authentication headers, so they can override them.
    pub headers: HashMap<String, String>,
}
36
/// Supported cloud storage backends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CloudProvider {
    /// Amazon S3.
    S3,
    /// Google Cloud Storage.
    GCS,
    /// Azure Blob Storage.
    Azure,
    /// Any S3-compatible store (MinIO, Ceph, ...); requires
    /// `CloudConfig::endpoint` to be set.
    S3Compatible,
}
49
/// Authentication material for a cloud provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CloudCredentials {
    /// Static S3-style access/secret key pair.
    AccessKey {
        access_key: String,
        secret_key: String,
        /// Optional temporary STS token, sent as `X-Amz-Security-Token`.
        session_token: Option<String>,
    },
    /// GCS service account: path to the JSON key file on disk.
    ServiceAccount {
        key_file: String,
    },
    /// Azure shared-key credentials.
    AzureKey {
        accountname: String,
        /// Base64-encoded account key.
        account_key: String,
    },
    /// Read credentials from the environment (currently rejected as
    /// unimplemented by `add_authentication_headers`).
    Environment,
    /// No authentication; suitable only for public buckets.
    Anonymous,
}
79
/// Client for loading, listing, and (mock-)uploading datasets stored in
/// cloud object storage.
pub struct CloudClient {
    config: CloudConfig,
    /// Local on-disk cache for downloaded datasets.
    cache: DatasetCache,
    // NOTE(review): currently unused — `load_dataset` builds a fresh client
    // with per-request auth headers instead of reusing this one.
    #[allow(dead_code)]
    external_client: ExternalClient,
}
87
88impl CloudClient {
89 pub fn new(config: CloudConfig) -> Result<Self> {
91 let cachedir = dirs::cache_dir()
92 .ok_or_else(|| DatasetsError::Other("Could not determine cache directory".to_string()))?
93 .join("scirs2-datasets");
94 let cache = DatasetCache::new(cachedir);
95 let external_client = ExternalClient::new()?;
96
97 Ok(Self {
98 config,
99 cache,
100 external_client,
101 })
102 }
103
104 pub fn load_dataset(&self, key: &str) -> Result<Dataset> {
106 let cache_key = format!("cloud_{}_{}", self.config.bucket, key);
108 if let Ok(cached_data) = self.cache.read_cached(&cache_key) {
109 return self.parse_cached_data(&cached_data);
110 }
111
112 let url = self.build_url(key)?;
114
115 let mut external_config = crate::external::ExternalConfig::default();
117 self.add_authentication_headers(&mut external_config)?;
118
119 let external_client = ExternalClient::with_config(external_config)?;
120 let dataset = external_client.download_dataset_sync(&url, None)?;
121
122 if let Ok(serialized) = serde_json::to_vec(&dataset) {
124 let _ = self.cache.write_cached(&cache_key, &serialized);
125 }
126
127 Ok(dataset)
128 }
129
130 pub fn list_datasets(&self, prefix: Option<&str>) -> Result<Vec<String>> {
132 match self.config.provider {
133 CloudProvider::S3 | CloudProvider::S3Compatible => self.list_s3_objects(prefix),
134 CloudProvider::GCS => self.list_gcs_objects(prefix),
135 CloudProvider::Azure => self.list_azure_objects(prefix),
136 }
137 }
138
139 #[allow(dead_code)]
141 pub fn upload_dataset(&self, key: &str, dataset: &Dataset) -> Result<()> {
142 let serialized =
143 serde_json::to_vec(dataset).map_err(|e| DatasetsError::SerdeError(e.to_string()))?;
144
145 self.upload_data(key, &serialized, "application/json")
146 }
147
148 fn build_url(&self, key: &str) -> Result<String> {
150 match self.config.provider {
151 CloudProvider::S3 => {
152 let region = self.config.region.as_deref().unwrap_or("us-east-1");
153 if self.config.path_style {
154 Ok(format!(
155 "https://s3.{}.amazonaws.com/{}/{}",
156 region, self.config.bucket, key
157 ))
158 } else {
159 Ok(format!(
160 "https://{}.s3.{}.amazonaws.com/{}",
161 self.config.bucket, region, key
162 ))
163 }
164 }
165 CloudProvider::S3Compatible => {
166 let endpoint = self.config.endpoint.as_ref().ok_or_else(|| {
167 DatasetsError::InvalidFormat(
168 "S3-compatible storage requires endpoint".to_string(),
169 )
170 })?;
171
172 if self.config.path_style {
173 Ok(format!("{}/{}/{}", endpoint, self.config.bucket, key))
174 } else {
175 Ok(format!(
176 "https://{}.{}/{}",
177 self.config.bucket,
178 endpoint.trim_start_matches("https://"),
179 key
180 ))
181 }
182 }
183 CloudProvider::GCS => Ok(format!(
184 "https://storage.googleapis.com/{}/{}",
185 self.config.bucket, key
186 )),
187 CloudProvider::Azure => {
188 let accountname = match &self.config.credentials {
189 CloudCredentials::AzureKey { accountname, .. } => accountname,
190 _ => {
191 return Err(DatasetsError::InvalidFormat(
192 "Azure requires account name in credentials".to_string(),
193 ))
194 }
195 };
196 Ok(format!(
197 "https://{}.blob.core.windows.net/{}/{}",
198 accountname, self.config.bucket, key
199 ))
200 }
201 }
202 }
203
    /// Populates `config.headers` with provider-appropriate authentication
    /// headers, then merges the user-supplied headers from `CloudConfig`
    /// (user headers win on key collision because they are inserted last).
    ///
    /// # Errors
    /// Returns `AuthenticationError` for `Environment` credentials (not
    /// implemented) and for credential/provider combinations that do not
    /// match; propagates errors from GCS token minting and Azure signing.
    fn add_authentication_headers(
        &self,
        config: &mut crate::external::ExternalConfig,
    ) -> Result<()> {
        match (&self.config.provider, &self.config.credentials) {
            (
                CloudProvider::S3 | CloudProvider::S3Compatible,
                CloudCredentials::AccessKey {
                    access_key,
                    secret_key,
                    session_token,
                },
            ) => {
                // NOTE(review): this is NOT a valid AWS signature — it places
                // the raw secret key in the Authorization header. Real S3
                // requests require SigV4 request signing; confirm before any
                // production use.
                config.headers.insert(
                    "Authorization".to_string(),
                    format!("AWS {access_key}:{secret_key}"),
                );

                // Temporary STS credentials also require the session token.
                if let Some(token) = session_token {
                    config
                        .headers
                        .insert("X-Amz-Security-Token".to_string(), token.clone());
                }
            }
            (CloudProvider::GCS, CloudCredentials::ServiceAccount { key_file }) => {
                // `get_gcs_token` currently always returns Err (JWT signing is
                // unimplemented), so authenticated GCS requests fail here.
                config.headers.insert(
                    "Authorization".to_string(),
                    format!("Bearer {}", self.get_gcs_token(key_file)?),
                );
            }
            (
                CloudProvider::Azure,
                CloudCredentials::AzureKey {
                    accountname,
                    account_key,
                },
            ) => {
                let auth_header = self.create_azure_auth_header(accountname, account_key)?;
                config
                    .headers
                    .insert("Authorization".to_string(), auth_header);
            }
            (_, CloudCredentials::Anonymous) => {
                // Public access: no auth headers needed.
            }
            (_, CloudCredentials::Environment) => {
                return Err(DatasetsError::AuthenticationError(
                    "Environment credentials not implemented".to_string(),
                ));
            }
            _ => {
                return Err(DatasetsError::AuthenticationError(
                    "Invalid credential type for provider".to_string(),
                ));
            }
        }

        // User-configured headers are merged last so they can override the
        // defaults set above.
        for (key, value) in &self.config.headers {
            config.headers.insert(key.clone(), value.clone());
        }

        Ok(())
    }
275
276 fn parse_cached_data(&self, data: &[u8]) -> Result<Dataset> {
277 serde_json::from_slice(data).map_err(|e| DatasetsError::SerdeError(e.to_string()))
278 }
279
280 #[allow(dead_code)]
281 fn get_gcs_token(&self, keyfile: &str) -> Result<String> {
282 let key_data = std::fs::read_to_string(keyfile).map_err(|e| {
284 DatasetsError::LoadingError(format!("Failed to read key file {keyfile}: {e}"))
285 })?;
286
287 let service_account: serde_json::Value = serde_json::from_str(&key_data)
288 .map_err(|e| DatasetsError::SerdeError(format!("Invalid service account JSON: {e}")))?;
289
290 let client_email = service_account["client_email"].as_str().ok_or_else(|| {
292 DatasetsError::AuthenticationError(
293 "Missing client_email in service account".to_string(),
294 )
295 })?;
296
297 let _private_key = service_account["private_key"].as_str().ok_or_else(|| {
298 DatasetsError::AuthenticationError("Missing private_key in service account".to_string())
299 })?;
300
301 let now = std::time::SystemTime::now()
303 .duration_since(std::time::UNIX_EPOCH)
304 .map_err(|e| DatasetsError::Other(format!("Time error: {e}")))?
305 .as_secs();
306
307 let claims = serde_json::json!({
308 "iss": client_email,
309 "scope": "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/devstorage.read_write",
310 "aud": "https://oauth2.googleapis.com/token",
311 "exp": now + 3600, "iat": now
313 });
314
315 Err(DatasetsError::AuthenticationError(format!(
322 "GCS authentication requires JWT signing implementation. Service account: {client_email}, Claims: {claims}.
323 To complete implementation:
324 1. Add 'jsonwebtoken' crate dependency
325 2. Implement RS256 JWT signing with private key
326 3. Exchange signed JWT for OAuth2 access token at https://oauth2.googleapis.com/token"
327 )))
328 }
329
330 #[allow(dead_code)]
331 fn create_azure_auth_header(&self, accountname: &str, accountkey: &str) -> Result<String> {
332 let now = std::time::SystemTime::now()
340 .duration_since(std::time::UNIX_EPOCH)
341 .map_err(|e| DatasetsError::Other(format!("Time error: {e}")))?;
342
343 let timestamp = format_azure_timestamp(now.as_secs());
345
346 let account_key_bytes = base64_decode(accountkey).map_err(|_| {
348 DatasetsError::AuthenticationError("Invalid base64 account _key".to_string())
349 })?;
350
351 if account_key_bytes.is_empty() {
352 return Err(DatasetsError::AuthenticationError(
353 "Empty account _key".to_string(),
354 ));
355 }
356
357 let string_to_sign = format!(
360 "GET\n\n\n\n\n\n\n\n\n\n\n\nx-ms-date:{timestamp}\nx-ms-version:2020-04-08\n/{accountname}"
361 );
362
363 let signature = hmac_sha256(&account_key_bytes, string_to_sign.as_bytes())
365 .map_err(DatasetsError::Other)?;
366
367 let signature_b64 = base64_encode(&signature);
369
370 let auth_header = format!("SharedKey {accountname}:{signature_b64}");
372
373 Ok(auth_header)
374 }
375
    /// Computes HMAC-SHA256 (RFC 2104) of `message` with `key`, using the
    /// manual inner/outer-pad construction over the `sha2` crate's SHA-256.
    ///
    /// Keys longer than the 64-byte SHA-256 block size are first hashed down
    /// to 32 bytes; shorter keys are zero-padded to the block size.
    ///
    /// NOTE(review): this associated function is currently dead code — the
    /// unqualified `hmac_sha256(...)` call in `create_azure_auth_header`
    /// resolves to the module-level placeholder function instead.
    #[allow(dead_code)]
    fn hmac_sha256(key: &[u8], message: &[u8]) -> Result<Vec<u8>> {
        use sha2::{Digest, Sha256};

        // SHA-256 block size and the standard HMAC pad bytes.
        const BLOCK_SIZE: usize = 64;
        const IPAD: u8 = 0x36;
        const OPAD: u8 = 0x5C;

        // Normalize the key to exactly one block.
        let mut padded_key = [0u8; BLOCK_SIZE];

        if key.len() > BLOCK_SIZE {
            // Over-long keys are replaced by their 32-byte SHA-256 digest.
            let mut hasher = Sha256::new();
            hasher.update(key);
            let hashed_key = hasher.finalize();
            padded_key[..hashed_key.len()].copy_from_slice(&hashed_key);
        } else {
            padded_key[..key.len()].copy_from_slice(key);
        }

        // Derive the inner (key ^ ipad) and outer (key ^ opad) pads.
        let mut inner_key = [0u8; BLOCK_SIZE];
        let mut outer_key = [0u8; BLOCK_SIZE];

        for i in 0..BLOCK_SIZE {
            inner_key[i] = padded_key[i] ^ IPAD;
            outer_key[i] = padded_key[i] ^ OPAD;
        }

        // HMAC = H(outer_key || H(inner_key || message)).
        let mut inner_hasher = Sha256::new();
        inner_hasher.update(inner_key);
        inner_hasher.update(message);
        let inner_hash = inner_hasher.finalize();

        let mut outer_hasher = Sha256::new();
        outer_hasher.update(outer_key);
        outer_hasher.update(inner_hash);
        let final_hash = outer_hasher.finalize();

        Ok(final_hash.to_vec())
    }
424
425 fn list_s3_objects(&self, prefix: Option<&str>) -> Result<Vec<String>> {
426 let list_url = match self.config.provider {
427 CloudProvider::S3 => {
428 let region = self.config.region.as_deref().unwrap_or("us-east-1");
429 format!(
430 "https://s3.{}.amazonaws.com/{}/?list-type=2",
431 region, self.config.bucket
432 )
433 }
434 CloudProvider::S3Compatible => {
435 let endpoint = self.config.endpoint.as_ref().ok_or_else(|| {
436 DatasetsError::InvalidFormat(
437 "S3-compatible storage requires endpoint".to_string(),
438 )
439 })?;
440 format!("{}/{}/?list-type=2", endpoint, self.config.bucket)
441 }
442 _ => unreachable!(),
443 };
444
445 let _url_with_prefix = if let Some(prefix) = prefix {
446 format!("{list_url}&prefix={prefix}")
447 } else {
448 list_url
449 };
450
451 match &self.config.credentials {
453 CloudCredentials::AccessKey {
454 access_key,
455 secret_key,
456 ..
457 } => {
458 if access_key.is_empty() || secret_key.is_empty() {
459 return Err(DatasetsError::AuthenticationError(
460 "S3 access key and secret key cannot be empty".to_string(),
461 ));
462 }
463 }
464 CloudCredentials::Anonymous => {
465 }
467 _ => {
468 return Err(DatasetsError::AuthenticationError(
469 "Invalid credentials for S3 access".to_string(),
470 ));
471 }
472 }
473
474 let mut mock_objects = vec![
478 "datasets/adult.csv".to_string(),
479 "datasets/titanic.csv".to_string(),
480 "datasets/iris.csv".to_string(),
481 "datasets/boston_housing.csv".to_string(),
482 "datasets/wine.csv".to_string(),
483 "models/classifier_v1.pkl".to_string(),
484 "models/regressor_v2.pkl".to_string(),
485 "raw_data/sensor_logs_2023.parquet".to_string(),
486 "processed/features_normalized.npz".to_string(),
487 "backup/archive_2023_q4.tar.gz".to_string(),
488 ];
489
490 if let Some(prefix) = prefix {
492 mock_objects.retain(|obj| obj.starts_with(prefix));
493 }
494
495 eprintln!(
497 "MOCK S3 LIST: {} objects in bucket '{}' with prefix '{}'",
498 mock_objects.len(),
499 self.config.bucket,
500 prefix.unwrap_or("(none)")
501 );
502
503 Ok(mock_objects)
504 }
505
506 fn list_gcs_objects(&self, prefix: Option<&str>) -> Result<Vec<String>> {
507 let list_url = format!(
508 "https://storage.googleapis.com/storage/v1/b/{}/o",
509 self.config.bucket
510 );
511
512 let _url_with_prefix = if let Some(prefix) = prefix {
513 format!("{list_url}?prefix={prefix}")
514 } else {
515 list_url
516 };
517
518 if let CloudCredentials::ServiceAccount { key_file } = &self.config.credentials {
520 if key_file.is_empty() {
521 return Err(DatasetsError::AuthenticationError(
522 "GCS service account key file path cannot be empty".to_string(),
523 ));
524 }
525
526 if !std::path::Path::new(key_file).exists() {
528 return Err(DatasetsError::LoadingError(format!(
529 "GCS service account key file not found: {key_file}"
530 )));
531 }
532 } else {
533 return Err(DatasetsError::AuthenticationError(
534 "GCS requires service account credentials".to_string(),
535 ));
536 }
537
538 let mut mock_objects = vec![
542 "ml_datasets/classification/breast_cancer.csv".to_string(),
543 "ml_datasets/classification/spam_detection.csv".to_string(),
544 "ml_datasets/regression/california_housing.csv".to_string(),
545 "ml_datasets/regression/energy_efficiency.csv".to_string(),
546 "ml_datasets/time_series/air_passengers.csv".to_string(),
547 "ml_datasets/time_series/bitcoin_prices.csv".to_string(),
548 "computer_vision/cifar10_subset.pkl".to_string(),
549 "computer_vision/fashion_mnist_subset.pkl".to_string(),
550 "nlp/imdb_reviews.json".to_string(),
551 "nlp/news_articles_categorized.json".to_string(),
552 "experiments/model_weights_20231201.h5".to_string(),
553 "experiments/hyperparameters_grid_search.yaml".to_string(),
554 ];
555
556 if let Some(prefix) = prefix {
558 mock_objects.retain(|obj| obj.starts_with(prefix));
559 }
560
561 eprintln!(
563 "MOCK GCS LIST: {} objects in bucket '{}' with prefix '{}'",
564 mock_objects.len(),
565 self.config.bucket,
566 prefix.unwrap_or("(none)")
567 );
568
569 Ok(mock_objects)
570 }
571
572 fn list_azure_objects(&self, prefix: Option<&str>) -> Result<Vec<String>> {
573 let accountname = match &self.config.credentials {
574 CloudCredentials::AzureKey { accountname, .. } => accountname,
575 _ => {
576 return Err(DatasetsError::InvalidFormat(
577 "Azure requires account name".to_string(),
578 ))
579 }
580 };
581
582 let list_url = format!(
583 "https://{}.blob.core.windows.net/{}?restype=container&comp=list",
584 accountname, self.config.bucket
585 );
586
587 let _url_with_prefix = if let Some(prefix) = prefix {
588 format!("{list_url}&prefix={prefix}")
589 } else {
590 list_url
591 };
592
593 let _accountname_account_key = match &self.config.credentials {
595 CloudCredentials::AzureKey {
596 accountname,
597 account_key,
598 } => {
599 if accountname.is_empty() {
600 return Err(DatasetsError::AuthenticationError(
601 "Azure account name cannot be empty".to_string(),
602 ));
603 }
604 if account_key.is_empty() {
605 return Err(DatasetsError::AuthenticationError(
606 "Azure account key cannot be empty".to_string(),
607 ));
608 }
609
610 if let Err(e) = base64_decode(account_key) {
612 return Err(DatasetsError::AuthenticationError(format!(
613 "Invalid Azure account key format (expected base64): {e}"
614 )));
615 }
616
617 (accountname, account_key)
618 }
619 _ => {
620 return Err(DatasetsError::AuthenticationError(
621 "Azure Blob Storage requires Azure account credentials".to_string(),
622 ));
623 }
624 };
625
626 let mut mock_objects = vec![
630 "healthcare/diabetes_readmission.csv".to_string(),
631 "healthcare/heart_disease_prediction.csv".to_string(),
632 "finance/credit_card_fraud.csv".to_string(),
633 "finance/loan_default_prediction.csv".to_string(),
634 "finance/stock_market_data_2023.csv".to_string(),
635 "retail/customer_segmentation.csv".to_string(),
636 "retail/product_recommendations.csv".to_string(),
637 "automotive/car_mpg_efficiency.csv".to_string(),
638 "materials/concrete_strength.csv".to_string(),
639 "energy/building_efficiency.csv".to_string(),
640 "telecommunications/network_performance.csv".to_string(),
641 "backup/daily_backup_20231201.blob".to_string(),
642 ];
643
644 if let Some(prefix) = prefix {
646 mock_objects.retain(|obj| obj.starts_with(prefix));
647 }
648
649 eprintln!(
651 "MOCK AZURE LIST: {} blobs in container '{}' (account: {}) with prefix '{}'",
652 mock_objects.len(),
653 self.config.bucket,
654 accountname,
655 prefix.unwrap_or("(none)")
656 );
657
658 Ok(mock_objects)
659 }
660
661 #[allow(dead_code)]
662 fn upload_data(&self, key: &str, data: &[u8], contenttype: &str) -> Result<()> {
663 let url = self.build_url(key)?;
664
665 if key.is_empty() {
670 return Err(DatasetsError::InvalidFormat(
671 "Key cannot be empty".to_string(),
672 ));
673 }
674
675 if data.is_empty() {
676 return Err(DatasetsError::InvalidFormat(
677 "Data cannot be empty".to_string(),
678 ));
679 }
680
681 match self.config.provider {
683 CloudProvider::S3 | CloudProvider::S3Compatible => {
684 match &self.config.credentials {
686 CloudCredentials::AccessKey {
687 access_key,
688 secret_key,
689 ..
690 } => {
691 if access_key.is_empty() || secret_key.is_empty() {
692 return Err(DatasetsError::AuthenticationError(
693 "S3 credentials missing".to_string(),
694 ));
695 }
696 }
697 CloudCredentials::Anonymous => {
698 return Err(DatasetsError::AuthenticationError(
699 "Cannot upload with anonymous credentials".to_string(),
700 ));
701 }
702 _ => {
703 return Err(DatasetsError::AuthenticationError(
704 "Invalid credentials for S3 upload".to_string(),
705 ));
706 }
707 }
708 }
709 CloudProvider::GCS => {
710 if let CloudCredentials::ServiceAccount { key_file } = &self.config.credentials {
711 if !std::path::Path::new(key_file).exists() {
712 return Err(DatasetsError::AuthenticationError(format!(
713 "GCS key file not found: {key_file}"
714 )));
715 }
716 } else {
717 return Err(DatasetsError::AuthenticationError(
718 "GCS requires service account credentials".to_string(),
719 ));
720 }
721 }
722 CloudProvider::Azure => match &self.config.credentials {
723 CloudCredentials::AzureKey {
724 accountname,
725 account_key,
726 } => {
727 if accountname.is_empty() || account_key.is_empty() {
728 return Err(DatasetsError::AuthenticationError(
729 "Azure credentials missing".to_string(),
730 ));
731 }
732 }
733 _ => {
734 return Err(DatasetsError::AuthenticationError(
735 "Azure requires account credentials".to_string(),
736 ));
737 }
738 },
739 }
740
741 eprintln!(
743 "MOCK UPLOAD: {} bytes to {} at {} (Content-Type: {})",
744 data.len(),
745 key,
746 url,
747 contenttype
748 );
749
750 Ok(())
752 }
753}
754
#[allow(dead_code)]
/// Formats a Unix timestamp as an RFC 1123 date ("Tue, 14 Nov 2023 22:13:20
/// GMT"), the format Azure requires in the `x-ms-date` header.
///
/// BUGFIX: the previous implementation approximated the calendar with
/// `((days % 365) % 31) + 1` for the day and `(days % 365) % 12` for the
/// month, with no leap-year handling, so almost every produced date was
/// wrong. Replaced with the exact civil-from-days conversion.
fn format_azure_timestamp(unix_timestamp: u64) -> String {
    const DAYS: [&str; 7] = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
    const MONTHS: [&str; 12] = [
        "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    ];

    // Converts days since 1970-01-01 to (year, month 1-12, day 1-31) in the
    // proleptic Gregorian calendar (Howard Hinnant's civil_from_days).
    fn civil_from_days(days: i64) -> (i64, usize, i64) {
        let z = days + 719_468; // shift epoch to 0000-03-01
        let era = z.div_euclid(146_097); // 400-year eras
        let doe = z.rem_euclid(146_097); // day of era [0, 146096]
        let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365;
        let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
        let mp = (5 * doy + 2) / 153; // month, March-based [0, 11]
        let day = doy - (153 * mp + 2) / 5 + 1;
        let month = if mp < 10 { mp + 3 } else { mp - 9 };
        let year = yoe + era * 400 + i64::from(month <= 2);
        (year, month as usize, day)
    }

    let days_since_epoch = (unix_timestamp / 86_400) as i64;
    // 1970-01-01 was a Thursday (index 4).
    let day_of_week = ((days_since_epoch + 4) % 7) as usize;
    let (year, month, day) = civil_from_days(days_since_epoch);

    let hour = (unix_timestamp % 86_400) / 3_600;
    let minute = (unix_timestamp % 3_600) / 60;
    let second = unix_timestamp % 60;

    format!(
        "{}, {:02} {} {} {:02}:{:02}:{:02} GMT",
        DAYS[day_of_week], day, MONTHS[month - 1], year, hour, minute, second
    )
}
780
#[allow(dead_code)]
/// Encodes `input` as standard (RFC 4648) base64 with `=` padding.
fn base64_encode(input: &[u8]) -> String {
    const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    if input.is_empty() {
        return String::new();
    }

    // Every 3-byte group becomes 4 output characters.
    let mut result = String::with_capacity((input.len() + 2) / 3 * 4);

    for chunk in input.chunks(3) {
        // Pack up to three bytes into a 24-bit group (missing bytes are 0).
        let b2 = chunk.get(1).copied().unwrap_or(0);
        let b3 = chunk.get(2).copied().unwrap_or(0);
        let triple = ((chunk[0] as u32) << 16) | ((b2 as u32) << 8) | (b3 as u32);

        result.push(BASE64_CHARS[((triple >> 18) & 0x3F) as usize] as char);
        result.push(BASE64_CHARS[((triple >> 12) & 0x3F) as usize] as char);

        // Positions 3 and 4 are padded when the chunk is short.
        result.push(if chunk.len() > 1 {
            BASE64_CHARS[((triple >> 6) & 0x3F) as usize] as char
        } else {
            '='
        });
        result.push(if chunk.len() > 2 {
            BASE64_CHARS[(triple & 0x3F) as usize] as char
        } else {
            '='
        });
    }

    result
}
821
#[allow(dead_code)]
/// Placeholder "HMAC" used by the mock Azure signing path.
///
/// WARNING: this is NOT a cryptographic HMAC — it feeds key and data through
/// `DefaultHasher` and repeats the 8-byte result to 32 bytes. Its output is
/// deterministic within one build but is not suitable for real
/// authentication; `CloudClient::hmac_sha256` holds the real implementation.
fn hmac_sha256(key: &[u8], data: &[u8]) -> std::result::Result<Vec<u8>, String> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    key.hash(&mut state);
    data.hash(&mut state);

    // Repeat the 8-byte digest to mimic a 32-byte SHA-256-sized output.
    let digest = state.finish().to_be_bytes();
    Ok(digest.repeat(4))
}
838
#[allow(dead_code)]
/// Decodes a standard (RFC 4648) base64 string, tolerating surrounding
/// whitespace and trailing `=` padding.
///
/// BUGFIX: the previous version never decoded anything — it validated the
/// *untrimmed* input (so any padded string containing `=` was rejected as an
/// "Invalid base64 character") and then returned the raw ASCII bytes of the
/// input instead of the decoded bytes.
///
/// # Errors
/// Returns an error when the input contains a character outside the base64
/// alphabet (after trimming whitespace and padding).
fn base64_decode(input: &str) -> std::result::Result<Vec<u8>, String> {
    const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let trimmed = input.trim();
    if trimmed.is_empty() {
        return Ok(Vec::new());
    }

    // Padding carries no data; drop it before decoding.
    let data = trimmed.trim_end_matches('=');

    let mut out = Vec::with_capacity(data.len() * 3 / 4);
    let mut buffer: u32 = 0; // bit accumulator
    let mut bits: u32 = 0; // number of valid bits in `buffer`

    for ch in data.bytes() {
        let value = BASE64_CHARS
            .iter()
            .position(|&c| c == ch)
            .ok_or_else(|| "Invalid base64 character".to_string())? as u32;

        buffer = (buffer << 6) | value;
        bits += 6;

        // Emit a byte whenever 8 or more bits have accumulated.
        if bits >= 8 {
            bits -= 8;
            out.push((buffer >> bits) as u8);
        }
    }

    Ok(out)
}
864
865pub mod presets {
867 use super::*;
868
869 pub fn s3_client(
871 region: &str,
872 bucket: &str,
873 access_key: &str,
874 secret_key: &str,
875 ) -> Result<CloudClient> {
876 let config = CloudConfig {
877 provider: CloudProvider::S3,
878 region: Some(region.to_string()),
879 bucket: bucket.to_string(),
880 credentials: CloudCredentials::AccessKey {
881 access_key: access_key.to_string(),
882 secret_key: secret_key.to_string(),
883 session_token: None,
884 },
885 endpoint: None,
886 path_style: false,
887 headers: HashMap::new(),
888 };
889
890 CloudClient::new(config)
891 }
892
893 pub fn gcs_client(bucket: &str, keyfile: &str) -> Result<CloudClient> {
895 let config = CloudConfig {
896 provider: CloudProvider::GCS,
897 region: None,
898 bucket: bucket.to_string(),
899 credentials: CloudCredentials::ServiceAccount {
900 key_file: keyfile.to_string(),
901 },
902 endpoint: None,
903 path_style: false,
904 headers: HashMap::new(),
905 };
906
907 CloudClient::new(config)
908 }
909
910 pub fn azure_client(
912 accountname: &str,
913 account_key: &str,
914 container: &str,
915 ) -> Result<CloudClient> {
916 let config = CloudConfig {
917 provider: CloudProvider::Azure,
918 region: None,
919 bucket: container.to_string(),
920 credentials: CloudCredentials::AzureKey {
921 accountname: accountname.to_string(),
922 account_key: account_key.to_string(),
923 },
924 endpoint: None,
925 path_style: false,
926 headers: HashMap::new(),
927 };
928
929 CloudClient::new(config)
930 }
931
932 pub fn s3_compatible_client(
934 endpoint: &str,
935 bucket: &str,
936 access_key: &str,
937 secret_key: &str,
938 path_style: bool,
939 ) -> Result<CloudClient> {
940 let config = CloudConfig {
941 provider: CloudProvider::S3Compatible,
942 region: None,
943 bucket: bucket.to_string(),
944 credentials: CloudCredentials::AccessKey {
945 access_key: access_key.to_string(),
946 secret_key: secret_key.to_string(),
947 session_token: None,
948 },
949 endpoint: Some(endpoint.to_string()),
950 path_style,
951 headers: HashMap::new(),
952 };
953
954 CloudClient::new(config)
955 }
956
957 pub fn public_s3_client(region: &str, bucket: &str) -> Result<CloudClient> {
959 let config = CloudConfig {
960 provider: CloudProvider::S3,
961 region: Some(region.to_string()),
962 bucket: bucket.to_string(),
963 credentials: CloudCredentials::Anonymous,
964 endpoint: None,
965 path_style: false,
966 headers: HashMap::new(),
967 };
968
969 CloudClient::new(config)
970 }
971}
972
973pub mod public_datasets {
975 use super::presets::*;
976 use super::*;
977
978 pub struct AWSOpenData;
980
981 impl AWSOpenData {
982 pub fn common_crawl_sample() -> Result<CloudClient> {
984 public_s3_client("us-east-1", "commoncrawl")
985 }
986
987 pub fn noaa_weather() -> Result<CloudClient> {
989 public_s3_client("us-east-1", "noaa-global-hourly-pds")
990 }
991
992 pub fn nasa_landsat() -> Result<CloudClient> {
994 public_s3_client("us-west-2", "landsat-pds")
995 }
996
997 pub fn nyc_taxi() -> Result<CloudClient> {
999 public_s3_client("us-east-1", "nyc-tlc")
1000 }
1001 }
1002
1003 pub struct GCPPublicData;
1005
1006 impl GCPPublicData {
1007 pub fn bigquery_samples(_keyfile: &str) -> Result<CloudClient> {
1009 gcs_client("bigquery-public-data", _keyfile)
1010 }
1011
1012 pub fn books_ngrams(_keyfile: &str) -> Result<CloudClient> {
1014 gcs_client("books", _keyfile)
1015 }
1016 }
1017
1018 pub struct AzureOpenData;
1020
1021 impl AzureOpenData {
1022 pub fn covid19_tracking(_accountname: &str, accountkey: &str) -> Result<CloudClient> {
1024 azure_client(_accountname, accountkey, "covid19-tracking")
1025 }
1026
1027 pub fn us_census(_accountname: &str, accountkey: &str) -> Result<CloudClient> {
1029 azure_client(_accountname, accountkey, "us-census")
1030 }
1031 }
1032}
1033
#[cfg(test)]
mod tests {
    use super::presets::*;
    use super::*;

    // A CloudConfig can be constructed directly and its fields read back.
    #[test]
    fn test_cloud_config_creation() {
        let config = CloudConfig {
            provider: CloudProvider::S3,
            region: Some("us-east-1".to_string()),
            bucket: "test-bucket".to_string(),
            credentials: CloudCredentials::Anonymous,
            endpoint: None,
            path_style: false,
            headers: HashMap::new(),
        };

        assert!(matches!(config.provider, CloudProvider::S3));
        assert_eq!(config.bucket, "test-bucket");
    }

    // Virtual-hosted-style S3 URL: the bucket is part of the hostname.
    #[test]
    fn test_s3_url_building() {
        let client = public_s3_client("us-east-1", "test-bucket").unwrap();
        let url = client.build_url("path/to/dataset.csv").unwrap();
        assert_eq!(
            url,
            "https://test-bucket.s3.us-east-1.amazonaws.com/path/to/dataset.csv"
        );
    }

    // Path-style S3 URL: the bucket is the first path segment.
    #[test]
    fn test_s3path_style_url() {
        let config = CloudConfig {
            provider: CloudProvider::S3,
            region: Some("us-east-1".to_string()),
            bucket: "test-bucket".to_string(),
            credentials: CloudCredentials::Anonymous,
            endpoint: None,
            path_style: true,
            headers: HashMap::new(),
        };

        let client = CloudClient::new(config).unwrap();
        let url = client.build_url("test.csv").unwrap();
        assert_eq!(
            url,
            "https://s3.us-east-1.amazonaws.com/test-bucket/test.csv"
        );
    }

    // GCS URLs always use the storage.googleapis.com host.
    #[test]
    fn test_gcs_url_building() {
        let client = gcs_client("test-bucket", "dummy-key.json").unwrap();
        let url = client.build_url("data/file.json").unwrap();
        assert_eq!(
            url,
            "https://storage.googleapis.com/test-bucket/data/file.json"
        );
    }

    // Azure URLs embed the account name from the credentials in the host.
    #[test]
    fn test_azure_url_building() {
        let client = azure_client("testaccount", "dummykey", "container").unwrap();
        let url = client.build_url("blob.txt").unwrap();
        assert_eq!(
            url,
            "https://testaccount.blob.core.windows.net/container/blob.txt"
        );
    }

    // S3-compatible endpoint with path-style addressing.
    #[test]
    fn test_s3_compatible_url_building() {
        let client = s3_compatible_client(
            "https://minio.example.com",
            "my-bucket",
            "access",
            "secret",
            true,
        )
        .unwrap();

        let url = client.build_url("file.csv").unwrap();
        assert_eq!(url, "https://minio.example.com/my-bucket/file.csv");
    }

    // Public-dataset preset constructors succeed (no network involved).
    #[test]
    fn test_aws_open_data_clients() {
        let result = public_datasets::AWSOpenData::noaa_weather();
        assert!(result.is_ok());

        let result = public_datasets::AWSOpenData::nyc_taxi();
        assert!(result.is_ok());
    }
}