Skip to main content

s3rm_rs/storage/s3/
client_builder.rs

1use aws_config::meta::region::{ProvideRegion, RegionProviderChain};
2use aws_config::retry::RetryConfig;
3use aws_config::{BehaviorVersion, ConfigLoader};
4use aws_runtime::env_config::file::{EnvConfigFileKind, EnvConfigFiles};
5use aws_sdk_s3::Client;
6use aws_sdk_s3::config::Builder;
7use std::time::Duration;
8
9use crate::config::ClientConfig;
10use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;
11use aws_smithy_types::timeout::TimeoutConfig;
12use aws_types::SdkConfig;
13use aws_types::region::Region;
14
15impl ClientConfig {
16    pub async fn create_client(&self) -> Client {
17        let mut config_builder = Builder::from(&self.load_sdk_config().await)
18            .force_path_style(self.force_path_style)
19            .request_checksum_calculation(self.request_checksum_calculation)
20            .accelerate(self.accelerate);
21
22        if let Some(timeout_config) = self.build_timeout_config() {
23            config_builder = config_builder.timeout_config(timeout_config);
24        }
25
26        Client::from_conf(config_builder.build())
27    }
28
29    async fn load_sdk_config(&self) -> SdkConfig {
30        let config_loader = if self.disable_stalled_stream_protection {
31            aws_config::defaults(BehaviorVersion::latest())
32                .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
33        } else {
34            aws_config::defaults(BehaviorVersion::latest())
35                .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build())
36        };
37        let mut config_loader = self
38            .load_config_credential(config_loader)
39            .region(self.build_region_provider())
40            .retry_config(self.build_retry_config());
41
42        if let Some(endpoint_url) = &self.endpoint_url {
43            config_loader = config_loader.endpoint_url(endpoint_url);
44        };
45
46        config_loader.load().await
47    }
48
49    fn load_config_credential(&self, mut config_loader: ConfigLoader) -> ConfigLoader {
50        match &self.credential {
51            crate::types::S3Credentials::Credentials { access_keys } => {
52                let credentials = aws_sdk_s3::config::Credentials::new(
53                    access_keys.access_key.to_string(),
54                    access_keys.secret_access_key.to_string(),
55                    access_keys.session_token.clone(),
56                    None,
57                    "",
58                );
59                config_loader = config_loader.credentials_provider(credentials);
60            }
61            crate::types::S3Credentials::Profile(profile_name) => {
62                let mut builder = aws_config::profile::ProfileFileCredentialsProvider::builder();
63
64                if let Some(aws_shared_credentials_file) = self
65                    .client_config_location
66                    .aws_shared_credentials_file
67                    .as_ref()
68                {
69                    let profile_files = EnvConfigFiles::builder()
70                        .with_file(EnvConfigFileKind::Credentials, aws_shared_credentials_file)
71                        .build();
72                    builder = builder.profile_files(profile_files)
73                }
74
75                config_loader =
76                    config_loader.credentials_provider(builder.profile_name(profile_name).build());
77            }
78            crate::types::S3Credentials::FromEnvironment => {}
79        }
80        config_loader
81    }
82
83    fn build_region_provider(&self) -> Box<dyn ProvideRegion> {
84        let mut builder = aws_config::profile::ProfileFileRegionProvider::builder();
85
86        if let crate::types::S3Credentials::Profile(profile_name) = &self.credential {
87            if let Some(aws_config_file) = self.client_config_location.aws_config_file.as_ref() {
88                let profile_files = EnvConfigFiles::builder()
89                    .with_file(EnvConfigFileKind::Config, aws_config_file)
90                    .build();
91                builder = builder.profile_files(profile_files);
92            }
93            builder = builder.profile_name(profile_name)
94        }
95
96        let provider_region = if matches!(
97            &self.credential,
98            crate::types::S3Credentials::FromEnvironment
99        ) {
100            RegionProviderChain::first_try(self.region.clone().map(Region::new))
101                .or_default_provider()
102        } else {
103            RegionProviderChain::first_try(self.region.clone().map(Region::new))
104                .or_else(builder.build())
105        };
106
107        Box::new(provider_region)
108    }
109
110    fn build_retry_config(&self) -> RetryConfig {
111        RetryConfig::standard()
112            .with_max_attempts(self.retry_config.aws_max_attempts)
113            .with_initial_backoff(std::time::Duration::from_millis(
114                self.retry_config.initial_backoff_milliseconds,
115            ))
116    }
117
118    fn build_timeout_config(&self) -> Option<TimeoutConfig> {
119        let operation_timeout = self
120            .cli_timeout_config
121            .operation_timeout_milliseconds
122            .map(Duration::from_millis);
123        let operation_attempt_timeout = self
124            .cli_timeout_config
125            .operation_attempt_timeout_milliseconds
126            .map(Duration::from_millis);
127        let connect_timeout = self
128            .cli_timeout_config
129            .connect_timeout_milliseconds
130            .map(Duration::from_millis);
131        let read_timeout = self
132            .cli_timeout_config
133            .read_timeout_milliseconds
134            .map(Duration::from_millis);
135
136        if operation_timeout.is_none()
137            && operation_attempt_timeout.is_none()
138            && connect_timeout.is_none()
139            && read_timeout.is_none()
140        {
141            return None;
142        }
143
144        let mut builder = TimeoutConfig::builder();
145
146        builder = if let Some(operation_timeout) = operation_timeout {
147            builder.operation_timeout(operation_timeout)
148        } else {
149            builder
150        };
151
152        builder = if let Some(operation_attempt_timeout) = operation_attempt_timeout {
153            builder.operation_attempt_timeout(operation_attempt_timeout)
154        } else {
155            builder
156        };
157
158        builder = if let Some(connect_timeout) = connect_timeout {
159            builder.connect_timeout(connect_timeout)
160        } else {
161            builder
162        };
163
164        builder = if let Some(read_timeout) = read_timeout {
165            builder.read_timeout(read_timeout)
166        } else {
167            builder
168        };
169
170        Some(builder.build())
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::init_dummy_tracing_subscriber;
    use crate::types::{AccessKeys, ClientConfigLocation};
    use aws_smithy_types::checksum_config::RequestChecksumCalculation;

    const TEST_ENDPOINT: &str = "https://my.endpoint.local";
    const TEST_CONFIG_FILE: &str = "./test_data/test_config/config";
    const TEST_CREDENTIALS_FILE: &str = "./test_data/test_config/credentials";

    /// Static access-key credentials used by the non-profile tests.
    fn access_keys_credential(session_token: Option<&str>) -> crate::types::S3Credentials {
        crate::types::S3Credentials::Credentials {
            access_keys: AccessKeys {
                access_key: "my_access_key".to_string(),
                secret_access_key: "my_secret_access_key".to_string(),
                session_token: session_token.map(str::to_string),
            },
        }
    }

    /// CLI timeout options with nothing specified.
    fn no_timeouts() -> crate::config::CLITimeoutConfig {
        crate::config::CLITimeoutConfig {
            operation_timeout_milliseconds: None,
            operation_attempt_timeout_milliseconds: None,
            connect_timeout_milliseconds: None,
            read_timeout_milliseconds: None,
        }
    }

    /// CLI timeout options with every timeout set (1s, 2s, 3s, 4s).
    fn all_timeouts() -> crate::config::CLITimeoutConfig {
        crate::config::CLITimeoutConfig {
            operation_timeout_milliseconds: Some(1000),
            operation_attempt_timeout_milliseconds: Some(2000),
            connect_timeout_milliseconds: Some(3000),
            read_timeout_milliseconds: Some(4000),
        }
    }

    /// Baseline shared by every test: static keys with a session token, no
    /// custom config files, no explicit region, the dummy endpoint, 10 attempts
    /// with a 100 ms initial backoff, no CLI timeouts and all flags off.
    /// Individual tests override fields with struct-update syntax.
    fn base_config() -> ClientConfig {
        ClientConfig {
            client_config_location: ClientConfigLocation {
                aws_config_file: None,
                aws_shared_credentials_file: None,
            },
            credential: access_keys_credential(Some("my_session_token")),
            region: None,
            endpoint_url: Some(TEST_ENDPOINT.to_string()),
            force_path_style: false,
            retry_config: crate::config::RetryConfig {
                aws_max_attempts: 10,
                initial_backoff_milliseconds: 100,
            },
            cli_timeout_config: no_timeouts(),
            disable_stalled_stream_protection: false,
            request_checksum_calculation: RequestChecksumCalculation::WhenRequired,
            accelerate: false,
            request_payer: None,
        }
    }

    /// Baseline switched to a named profile read from the bundled test data files.
    fn profile_config(profile: &str, region: Option<&str>) -> ClientConfig {
        ClientConfig {
            client_config_location: ClientConfigLocation {
                aws_config_file: Some(TEST_CONFIG_FILE.into()),
                aws_shared_credentials_file: Some(TEST_CREDENTIALS_FILE.into()),
            },
            credential: crate::types::S3Credentials::Profile(profile.to_string()),
            region: region.map(str::to_string),
            ..base_config()
        }
    }

    #[track_caller]
    fn assert_retry_config(client: &Client, max_attempts: u32, initial_backoff_millis: u64) {
        let retry = client.config().retry_config().unwrap();
        assert_eq!(retry.max_attempts(), max_attempts);
        assert_eq!(
            retry.initial_backoff(),
            Duration::from_millis(initial_backoff_millis)
        );
    }

    #[track_caller]
    fn assert_region(client: &Client, expected: &str) {
        assert_eq!(
            client.config().region().unwrap().to_string(),
            expected.to_string()
        );
    }

    #[track_caller]
    fn assert_all_timeouts(client: &Client) {
        let timeouts = client.config().timeout_config().unwrap();
        assert_eq!(timeouts.operation_timeout(), Some(Duration::from_millis(1000)));
        assert_eq!(
            timeouts.operation_attempt_timeout(),
            Some(Duration::from_millis(2000))
        );
        assert_eq!(timeouts.connect_timeout(), Some(Duration::from_millis(3000)));
        assert_eq!(timeouts.read_timeout(), Some(Duration::from_millis(4000)));
        assert!(timeouts.has_timeouts());
    }

    #[tokio::test]
    async fn create_client_from_credentials() {
        init_dummy_tracing_subscriber();

        let config = ClientConfig {
            region: Some("my-region".to_string()),
            ..base_config()
        };
        let client = config.create_client().await;

        assert_retry_config(&client, 10, 100);

        let timeouts = client.config().timeout_config().unwrap();
        assert!(timeouts.operation_timeout().is_none());
        assert!(timeouts.operation_attempt_timeout().is_none());
        assert!(timeouts.connect_timeout().is_some());
        assert!(timeouts.read_timeout().is_none());
        assert!(timeouts.has_timeouts());

        assert_region(&client, "my-region");
    }

    #[tokio::test]
    async fn create_client_from_credentials_with_custom_timeouts() {
        init_dummy_tracing_subscriber();

        let config = ClientConfig {
            credential: access_keys_credential(None),
            retry_config: crate::config::RetryConfig {
                aws_max_attempts: 5,
                initial_backoff_milliseconds: 200,
            },
            cli_timeout_config: all_timeouts(),
            ..base_config()
        };
        let client = config.create_client().await;

        assert_retry_config(&client, 5, 200);
        assert_all_timeouts(&client);
    }

    #[tokio::test]
    async fn create_client_from_credentials_with_default_region() {
        init_dummy_tracing_subscriber();

        let config = ClientConfig {
            cli_timeout_config: all_timeouts(),
            ..base_config()
        };
        let client = config.create_client().await;

        assert_retry_config(&client, 10, 100);
        assert_all_timeouts(&client);
    }

    #[tokio::test]
    async fn create_client_from_custom_profile() {
        init_dummy_tracing_subscriber();

        let client = profile_config("aws", Some("my-region"))
            .create_client()
            .await;

        assert_retry_config(&client, 10, 100);
        assert_region(&client, "my-region");
    }

    #[tokio::test]
    async fn create_client_from_default_profile() {
        init_dummy_tracing_subscriber();

        let client = profile_config("default", None).create_client().await;

        assert_retry_config(&client, 10, 100);
        assert_region(&client, "us-west-1");
    }

    #[tokio::test]
    async fn create_client_from_custom_profile_overriding_region() {
        init_dummy_tracing_subscriber();

        let client = profile_config("aws", Some("my-region2"))
            .create_client()
            .await;

        assert_retry_config(&client, 10, 100);
        assert_region(&client, "my-region2");
    }

    #[tokio::test]
    async fn create_client_from_custom_timeout_connect_only() {
        init_dummy_tracing_subscriber();

        let config = ClientConfig {
            cli_timeout_config: crate::config::CLITimeoutConfig {
                connect_timeout_milliseconds: Some(1000),
                ..no_timeouts()
            },
            ..profile_config("aws", Some("my-region"))
        };
        let client = config.create_client().await;

        let timeouts = client.config().timeout_config().unwrap();
        assert!(timeouts.operation_timeout().is_none());
        assert!(timeouts.operation_attempt_timeout().is_none());
        assert_eq!(timeouts.connect_timeout(), Some(Duration::from_millis(1000)));
        assert!(timeouts.read_timeout().is_none());
        assert!(timeouts.has_timeouts());
    }

    #[tokio::test]
    async fn create_client_from_custom_timeout_operation_only() {
        init_dummy_tracing_subscriber();

        let config = ClientConfig {
            cli_timeout_config: crate::config::CLITimeoutConfig {
                operation_timeout_milliseconds: Some(1000),
                ..no_timeouts()
            },
            ..profile_config("aws", Some("my-region"))
        };
        let client = config.create_client().await;

        let timeouts = client.config().timeout_config().unwrap();
        assert_eq!(timeouts.operation_timeout(), Some(Duration::from_millis(1000)));
        assert!(timeouts.operation_attempt_timeout().is_none());
        assert!(timeouts.connect_timeout().is_some());
        assert!(timeouts.read_timeout().is_none());
        assert!(timeouts.has_timeouts());
    }

    #[tokio::test]
    async fn create_client_from_environment() {
        init_dummy_tracing_subscriber();

        let config = ClientConfig {
            credential: crate::types::S3Credentials::FromEnvironment,
            region: Some("us-east-1".to_string()),
            retry_config: crate::config::RetryConfig {
                aws_max_attempts: 3,
                initial_backoff_milliseconds: 100,
            },
            ..base_config()
        };
        let client = config.create_client().await;

        assert_region(&client, "us-east-1");
    }
}