Skip to main content

rusticity_core/
config.rs

1use anyhow::Result;
2
/// Resolved AWS session settings shared by the client constructors below.
#[derive(Debug, Clone)]
pub struct AwsConfig {
    /// Region name (for example `us-west-2`) that the service clients are built for.
    pub region: String,
    /// Account reported by STS `GetCallerIdentity`; empty when STS returned none
    /// or when the value was built with `AwsConfig::dummy`.
    pub account_id: String,
    /// ARN of the calling identity as reported by STS `GetCallerIdentity`; empty
    /// when STS returned none or when the value was built with `AwsConfig::dummy`.
    pub role_arn: String,
    /// `true` only when the region was picked by latency probing rather than
    /// taken from the loaded SDK configuration.
    pub region_auto_detected: bool,
}
10
11impl AwsConfig {
12    pub async fn new(region: Option<String>) -> Result<Self> {
13        Self::new_with_timeout(region, std::time::Duration::from_secs(10)).await
14    }
15
16    pub async fn new_with_timeout(
17        region: Option<String>,
18        timeout: std::time::Duration,
19    ) -> Result<Self> {
20        // Check for region early to avoid IMDS timeout
21        if region.is_none()
22            && std::env::var("AWS_REGION").is_err()
23            && std::env::var("AWS_DEFAULT_REGION").is_err()
24        {
25            return Err(anyhow::anyhow!("Missing Region"));
26        }
27
28        let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
29
30        // Use profile from AWS_PROFILE env var if set
31        let profile_to_use = std::env::var("AWS_PROFILE").ok();
32        if let Some(ref profile) = profile_to_use {
33            if !profile.is_empty() {
34                tracing::info!("Using AWS profile: {}", profile);
35                config_loader = config_loader.profile_name(profile);
36            }
37        }
38
39        if let Some(r) = &region {
40            config_loader = config_loader.region(aws_config::Region::new(r.clone()));
41        }
42
43        // Load config with timeout
44        let config = tokio::time::timeout(timeout, config_loader.load())
45            .await
46            .map_err(|_| anyhow::anyhow!("Timeout loading AWS config"))?;
47
48        // Double-check region is set
49        if config.region().is_none() {
50            return Err(anyhow::anyhow!("Missing Region"));
51        }
52
53        // Try to get identity with timeout
54        let (account_id, role_arn) =
55            match tokio::time::timeout(timeout, Self::try_get_identity(&config)).await {
56                Ok(Ok((acc, role))) => {
57                    tracing::info!("Loaded identity: account={}, role={}", acc, role);
58                    (acc, role)
59                }
60                Ok(Err(e)) => {
61                    tracing::error!("Failed to get identity: {}", e);
62                    return Err(e);
63                }
64                Err(_) => return Err(anyhow::anyhow!("Timeout getting AWS identity")),
65            };
66
67        let (region_str, auto_detected) = match config.region() {
68            Some(r) => (r.as_ref().to_string(), false),
69            None => {
70                let fastest = Self::find_fastest_region().await?;
71                (fastest, true)
72            }
73        };
74
75        Ok(Self {
76            region: region_str,
77            account_id,
78            role_arn,
79            region_auto_detected: auto_detected,
80        })
81    }
82
83    async fn try_get_identity(config: &aws_config::SdkConfig) -> Result<(String, String)> {
84        let sts_client = aws_sdk_sts::Client::new(config);
85        let identity = sts_client.get_caller_identity().send().await?;
86        let account_id = identity.account().unwrap_or("").to_string();
87        let role_arn = identity.arn().unwrap_or("").to_string();
88        Ok((account_id, role_arn))
89    }
90
91    async fn find_fastest_region() -> Result<String> {
92        use std::time::Instant;
93
94        let regions = [
95            "us-east-1",
96            "us-east-2",
97            "us-west-1",
98            "us-west-2",
99            "af-south-1",
100            "ap-east-1",
101            "ap-south-1",
102            "ap-south-2",
103            "ap-northeast-1",
104            "ap-northeast-2",
105            "ap-northeast-3",
106            "ap-southeast-1",
107            "ap-southeast-2",
108            "ap-southeast-3",
109            "ap-southeast-4",
110            "ca-central-1",
111            "ca-west-1",
112            "eu-central-1",
113            "eu-central-2",
114            "eu-west-1",
115            "eu-west-2",
116            "eu-west-3",
117            "eu-north-1",
118            "eu-south-1",
119            "eu-south-2",
120            "il-central-1",
121            "me-central-1",
122            "me-south-1",
123            "sa-east-1",
124        ];
125
126        let mut tasks = Vec::new();
127
128        for &region in &regions {
129            let region_name = region.to_string();
130            tasks.push(tokio::spawn(async move {
131                let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
132                    .region(aws_config::Region::new(region_name.clone()))
133                    .load()
134                    .await;
135                let s3 = aws_sdk_s3::Client::new(&config);
136                let start = Instant::now();
137                match tokio::time::timeout(
138                    std::time::Duration::from_secs(2),
139                    s3.list_buckets().send(),
140                )
141                .await
142                {
143                    Ok(Ok(_)) => Some((region_name, start.elapsed())),
144                    _ => Some((region_name, std::time::Duration::from_secs(9999))),
145                }
146            }));
147        }
148
149        let results = futures::future::join_all(tasks).await;
150        let mut latencies: Vec<(String, std::time::Duration)> = results
151            .into_iter()
152            .filter_map(|r| r.ok().flatten())
153            .collect();
154
155        latencies.sort_by_key(|(_, d)| *d);
156
157        latencies
158            .first()
159            .map(|(r, _)| r.clone())
160            .ok_or_else(|| anyhow::anyhow!("Could not determine fastest region"))
161    }
162
163    pub fn dummy(region: Option<String>) -> Self {
164        Self {
165            region: region.unwrap_or_default(),
166            account_id: "".to_string(),
167            role_arn: "".to_string(),
168            region_auto_detected: false,
169        }
170    }
171
172    pub async fn get_account_for_profile(profile: &str, region: &str) -> Result<String> {
173        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
174            .profile_name(profile)
175            .region(aws_config::Region::new(region.to_string()))
176            .load()
177            .await;
178
179        let sts_client = aws_sdk_sts::Client::new(&config);
180        let identity = sts_client.get_caller_identity().send().await?;
181        Ok(identity.account().unwrap_or("").to_string())
182    }
183
184    pub async fn s3_client(&self) -> aws_sdk_s3::Client {
185        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
186            .region(aws_config::Region::new(self.region.clone()))
187            .load()
188            .await;
189        aws_sdk_s3::Client::new(&config)
190    }
191
192    pub async fn s3_client_with_region(&self, region: &str) -> aws_sdk_s3::Client {
193        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
194            .region(aws_config::Region::new(region.to_string()))
195            .load()
196            .await;
197        aws_sdk_s3::Client::new(&config)
198    }
199
200    pub async fn cloudformation_client(&self) -> aws_sdk_cloudformation::Client {
201        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
202            .region(aws_config::Region::new(self.region.clone()))
203            .load()
204            .await;
205        aws_sdk_cloudformation::Client::new(&config)
206    }
207
208    pub async fn cloudtrail_client(&self) -> aws_sdk_cloudtrail::Client {
209        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
210            .region(aws_config::Region::new(self.region.clone()))
211            .load()
212            .await;
213        aws_sdk_cloudtrail::Client::new(&config)
214    }
215
216    pub async fn lambda_client(&self) -> aws_sdk_lambda::Client {
217        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
218            .region(aws_config::Region::new(self.region.clone()))
219            .load()
220            .await;
221        aws_sdk_lambda::Client::new(&config)
222    }
223
224    pub async fn iam_client(&self) -> aws_sdk_iam::Client {
225        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
226            .region(aws_config::Region::new(self.region.clone()))
227            .load()
228            .await;
229        aws_sdk_iam::Client::new(&config)
230    }
231
232    pub async fn ecr_client(&self) -> aws_sdk_ecr::Client {
233        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
234            .region(aws_config::Region::new(self.region.clone()))
235            .load()
236            .await;
237        aws_sdk_ecr::Client::new(&config)
238    }
239
240    pub async fn ecr_public_client(&self) -> aws_sdk_ecrpublic::Client {
241        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
242            .region(aws_config::Region::new(self.region.clone()))
243            .load()
244            .await;
245        aws_sdk_ecrpublic::Client::new(&config)
246    }
247
248    pub async fn cloudwatch_client(&self) -> aws_sdk_cloudwatch::Client {
249        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
250            .region(aws_config::Region::new(self.region.clone()))
251            .load()
252            .await;
253        aws_sdk_cloudwatch::Client::new(&config)
254    }
255
256    pub async fn sqs_client(&self) -> aws_sdk_sqs::Client {
257        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
258            .region(aws_config::Region::new(self.region.clone()))
259            .load()
260            .await;
261        aws_sdk_sqs::Client::new(&config)
262    }
263
264    pub async fn cloudwatch_logs_client(&self) -> aws_sdk_cloudwatchlogs::Client {
265        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
266            .region(aws_config::Region::new(self.region.clone()))
267            .load()
268            .await;
269        aws_sdk_cloudwatchlogs::Client::new(&config)
270    }
271
272    pub async fn pipes_client(&self) -> aws_sdk_pipes::Client {
273        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
274            .region(aws_config::Region::new(self.region.clone()))
275            .load()
276            .await;
277        aws_sdk_pipes::Client::new(&config)
278    }
279
280    pub async fn ec2_client(&self) -> aws_sdk_ec2::Client {
281        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
282            .region(aws_config::Region::new(self.region.clone()))
283            .load()
284            .await;
285        aws_sdk_ec2::Client::new(&config)
286    }
287
288    pub async fn apigateway_client(&self) -> aws_sdk_apigateway::Client {
289        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
290            .region(aws_config::Region::new(self.region.clone()))
291            .load()
292            .await;
293        aws_sdk_apigateway::Client::new(&config)
294    }
295
296    pub async fn apigatewayv2_client(&self) -> aws_sdk_apigatewayv2::Client {
297        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
298            .region(aws_config::Region::new(self.region.clone()))
299            .load()
300            .await;
301        aws_sdk_apigatewayv2::Client::new(&config)
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// A dummy built with a region keeps it and leaves the account blank.
    #[test]
    fn test_dummy_config_with_region() {
        let expected_region = "us-west-2";
        let dummy = AwsConfig::dummy(Some(String::from(expected_region)));
        assert_eq!(dummy.region, expected_region);
        assert_eq!(dummy.account_id, "");
        assert!(!dummy.region_auto_detected);
    }

    /// A dummy built without a region falls back to an empty region string.
    #[test]
    fn test_dummy_config_without_region() {
        let dummy = AwsConfig::dummy(None);
        assert_eq!(dummy.region, "");
        assert_eq!(dummy.account_id, "");
        assert!(!dummy.region_auto_detected);
    }

    /// With no usable credentials, construction fails with an auth-related error.
    ///
    /// NOTE(review): this mutates process-wide environment variables and never
    /// restores them; other tests in the same process will observe the change.
    #[tokio::test]
    async fn test_new_fails_without_credentials() {
        // Drop static credentials and point at a profile that does not exist.
        std::env::remove_var("AWS_ACCESS_KEY_ID");
        std::env::remove_var("AWS_SECRET_ACCESS_KEY");
        std::env::remove_var("AWS_SESSION_TOKEN");
        std::env::set_var("AWS_PROFILE", "nonexistent-profile-test");

        let outcome = AwsConfig::new(Some(String::from("us-east-1"))).await;

        // The region is supplied, so the failure has to come from authentication.
        assert!(outcome.is_err());
        let err_str = outcome.unwrap_err().to_string();
        // A dispatch failure is accepted too; it also indicates an auth problem.
        let looks_like_auth_failure = ["credentials", "profile", "dispatch"]
            .iter()
            .any(|needle| err_str.contains(*needle));
        assert!(
            looks_like_auth_failure,
            "Expected auth error, got: {}",
            err_str
        );
    }

    /// A caller-supplied timeout is accepted and construction still fails
    /// without credentials.
    ///
    /// NOTE(review): same process-wide environment mutation as above.
    #[tokio::test]
    async fn test_timeout_is_configurable() {
        std::env::remove_var("AWS_ACCESS_KEY_ID");
        std::env::remove_var("AWS_SECRET_ACCESS_KEY");
        std::env::set_var("AWS_PROFILE", "nonexistent-profile-test");

        // Deliberately tiny timeout.
        let short_timeout = std::time::Duration::from_millis(100);
        let outcome =
            AwsConfig::new_with_timeout(Some(String::from("us-east-1")), short_timeout).await;

        assert!(outcome.is_err());
    }

    /// Every field of a dummy has its documented placeholder value.
    #[test]
    fn test_dummy_config_preserves_values() {
        let expected_region = "eu-west-1";
        let dummy = AwsConfig::dummy(Some(String::from(expected_region)));
        assert_eq!(dummy.region, expected_region);
        assert_eq!(dummy.account_id, "");
        assert_eq!(dummy.role_arn, "");
        assert!(!dummy.region_auto_detected);
    }
}