1use anyhow::Result;
2
/// Resolved AWS environment details shared by the rest of the application.
///
/// All fields are plain owned values, so the struct is cheap to clone and
/// carries no live SDK handles.
#[derive(Debug, Clone)]
pub struct AwsConfig {
    /// Name of the AWS region, e.g. `us-west-2`. May be empty for placeholder
    /// instances.
    pub region: String,
    /// Account id of the calling identity; empty when it is not known.
    pub account_id: String,
    /// ARN of the calling identity; empty when it is not known.
    pub role_arn: String,
    /// Whether `region` was picked automatically instead of coming from
    /// explicit configuration.
    pub region_auto_detected: bool,
}
10
11impl AwsConfig {
12 pub async fn new(region: Option<String>) -> Result<Self> {
13 Self::new_with_timeout(region, std::time::Duration::from_secs(10)).await
14 }
15
16 pub async fn new_with_timeout(
17 region: Option<String>,
18 timeout: std::time::Duration,
19 ) -> Result<Self> {
20 if region.is_none()
22 && std::env::var("AWS_REGION").is_err()
23 && std::env::var("AWS_DEFAULT_REGION").is_err()
24 {
25 return Err(anyhow::anyhow!("Missing Region"));
26 }
27
28 let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
29
30 let profile_to_use = std::env::var("AWS_PROFILE").ok();
32 if let Some(ref profile) = profile_to_use {
33 if !profile.is_empty() {
34 tracing::info!("Using AWS profile: {}", profile);
35 config_loader = config_loader.profile_name(profile);
36 }
37 }
38
39 if let Some(r) = ®ion {
40 config_loader = config_loader.region(aws_config::Region::new(r.clone()));
41 }
42
43 let config = tokio::time::timeout(timeout, config_loader.load())
45 .await
46 .map_err(|_| anyhow::anyhow!("Timeout loading AWS config"))?;
47
48 if config.region().is_none() {
50 return Err(anyhow::anyhow!("Missing Region"));
51 }
52
53 let (account_id, role_arn) =
55 match tokio::time::timeout(timeout, Self::try_get_identity(&config)).await {
56 Ok(Ok((acc, role))) => {
57 tracing::info!("Loaded identity: account={}, role={}", acc, role);
58 (acc, role)
59 }
60 Ok(Err(e)) => {
61 tracing::error!("Failed to get identity: {}", e);
62 return Err(e);
63 }
64 Err(_) => return Err(anyhow::anyhow!("Timeout getting AWS identity")),
65 };
66
67 let (region_str, auto_detected) = match config.region() {
68 Some(r) => (r.as_ref().to_string(), false),
69 None => {
70 let fastest = Self::find_fastest_region().await?;
71 (fastest, true)
72 }
73 };
74
75 Ok(Self {
76 region: region_str,
77 account_id,
78 role_arn,
79 region_auto_detected: auto_detected,
80 })
81 }
82
83 async fn try_get_identity(config: &aws_config::SdkConfig) -> Result<(String, String)> {
84 let sts_client = aws_sdk_sts::Client::new(config);
85 let identity = sts_client.get_caller_identity().send().await?;
86 let account_id = identity.account().unwrap_or("").to_string();
87 let role_arn = identity.arn().unwrap_or("").to_string();
88 Ok((account_id, role_arn))
89 }
90
91 async fn find_fastest_region() -> Result<String> {
92 use std::time::Instant;
93
94 let regions = [
95 "us-east-1",
96 "us-east-2",
97 "us-west-1",
98 "us-west-2",
99 "af-south-1",
100 "ap-east-1",
101 "ap-south-1",
102 "ap-south-2",
103 "ap-northeast-1",
104 "ap-northeast-2",
105 "ap-northeast-3",
106 "ap-southeast-1",
107 "ap-southeast-2",
108 "ap-southeast-3",
109 "ap-southeast-4",
110 "ca-central-1",
111 "ca-west-1",
112 "eu-central-1",
113 "eu-central-2",
114 "eu-west-1",
115 "eu-west-2",
116 "eu-west-3",
117 "eu-north-1",
118 "eu-south-1",
119 "eu-south-2",
120 "il-central-1",
121 "me-central-1",
122 "me-south-1",
123 "sa-east-1",
124 ];
125
126 let mut tasks = Vec::new();
127
128 for ®ion in ®ions {
129 let region_name = region.to_string();
130 tasks.push(tokio::spawn(async move {
131 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
132 .region(aws_config::Region::new(region_name.clone()))
133 .load()
134 .await;
135 let s3 = aws_sdk_s3::Client::new(&config);
136 let start = Instant::now();
137 match tokio::time::timeout(
138 std::time::Duration::from_secs(2),
139 s3.list_buckets().send(),
140 )
141 .await
142 {
143 Ok(Ok(_)) => Some((region_name, start.elapsed())),
144 _ => Some((region_name, std::time::Duration::from_secs(9999))),
145 }
146 }));
147 }
148
149 let results = futures::future::join_all(tasks).await;
150 let mut latencies: Vec<(String, std::time::Duration)> = results
151 .into_iter()
152 .filter_map(|r| r.ok().flatten())
153 .collect();
154
155 latencies.sort_by_key(|(_, d)| *d);
156
157 latencies
158 .first()
159 .map(|(r, _)| r.clone())
160 .ok_or_else(|| anyhow::anyhow!("Could not determine fastest region"))
161 }
162
163 pub fn dummy(region: Option<String>) -> Self {
164 Self {
165 region: region.unwrap_or_default(),
166 account_id: "".to_string(),
167 role_arn: "".to_string(),
168 region_auto_detected: false,
169 }
170 }
171
172 pub async fn get_account_for_profile(profile: &str, region: &str) -> Result<String> {
173 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
174 .profile_name(profile)
175 .region(aws_config::Region::new(region.to_string()))
176 .load()
177 .await;
178
179 let sts_client = aws_sdk_sts::Client::new(&config);
180 let identity = sts_client.get_caller_identity().send().await?;
181 Ok(identity.account().unwrap_or("").to_string())
182 }
183
184 pub async fn s3_client(&self) -> aws_sdk_s3::Client {
185 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
186 .region(aws_config::Region::new(self.region.clone()))
187 .load()
188 .await;
189 aws_sdk_s3::Client::new(&config)
190 }
191
192 pub async fn s3_client_with_region(&self, region: &str) -> aws_sdk_s3::Client {
193 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
194 .region(aws_config::Region::new(region.to_string()))
195 .load()
196 .await;
197 aws_sdk_s3::Client::new(&config)
198 }
199
200 pub async fn cloudformation_client(&self) -> aws_sdk_cloudformation::Client {
201 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
202 .region(aws_config::Region::new(self.region.clone()))
203 .load()
204 .await;
205 aws_sdk_cloudformation::Client::new(&config)
206 }
207
208 pub async fn cloudtrail_client(&self) -> aws_sdk_cloudtrail::Client {
209 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
210 .region(aws_config::Region::new(self.region.clone()))
211 .load()
212 .await;
213 aws_sdk_cloudtrail::Client::new(&config)
214 }
215
216 pub async fn lambda_client(&self) -> aws_sdk_lambda::Client {
217 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
218 .region(aws_config::Region::new(self.region.clone()))
219 .load()
220 .await;
221 aws_sdk_lambda::Client::new(&config)
222 }
223
224 pub async fn iam_client(&self) -> aws_sdk_iam::Client {
225 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
226 .region(aws_config::Region::new(self.region.clone()))
227 .load()
228 .await;
229 aws_sdk_iam::Client::new(&config)
230 }
231
232 pub async fn ecr_client(&self) -> aws_sdk_ecr::Client {
233 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
234 .region(aws_config::Region::new(self.region.clone()))
235 .load()
236 .await;
237 aws_sdk_ecr::Client::new(&config)
238 }
239
240 pub async fn ecr_public_client(&self) -> aws_sdk_ecrpublic::Client {
241 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
242 .region(aws_config::Region::new(self.region.clone()))
243 .load()
244 .await;
245 aws_sdk_ecrpublic::Client::new(&config)
246 }
247
248 pub async fn cloudwatch_client(&self) -> aws_sdk_cloudwatch::Client {
249 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
250 .region(aws_config::Region::new(self.region.clone()))
251 .load()
252 .await;
253 aws_sdk_cloudwatch::Client::new(&config)
254 }
255
256 pub async fn sqs_client(&self) -> aws_sdk_sqs::Client {
257 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
258 .region(aws_config::Region::new(self.region.clone()))
259 .load()
260 .await;
261 aws_sdk_sqs::Client::new(&config)
262 }
263
264 pub async fn cloudwatch_logs_client(&self) -> aws_sdk_cloudwatchlogs::Client {
265 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
266 .region(aws_config::Region::new(self.region.clone()))
267 .load()
268 .await;
269 aws_sdk_cloudwatchlogs::Client::new(&config)
270 }
271
272 pub async fn pipes_client(&self) -> aws_sdk_pipes::Client {
273 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
274 .region(aws_config::Region::new(self.region.clone()))
275 .load()
276 .await;
277 aws_sdk_pipes::Client::new(&config)
278 }
279
280 pub async fn ec2_client(&self) -> aws_sdk_ec2::Client {
281 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
282 .region(aws_config::Region::new(self.region.clone()))
283 .load()
284 .await;
285 aws_sdk_ec2::Client::new(&config)
286 }
287
288 pub async fn apigateway_client(&self) -> aws_sdk_apigateway::Client {
289 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
290 .region(aws_config::Region::new(self.region.clone()))
291 .load()
292 .await;
293 aws_sdk_apigateway::Client::new(&config)
294 }
295
296 pub async fn apigatewayv2_client(&self) -> aws_sdk_apigatewayv2::Client {
297 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
298 .region(aws_config::Region::new(self.region.clone()))
299 .load()
300 .await;
301 aws_sdk_apigatewayv2::Client::new(&config)
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// A supplied region is stored verbatim; identity fields stay empty.
    #[test]
    fn test_dummy_config_with_region() {
        let expected_region = "us-west-2";
        let cfg = AwsConfig::dummy(Some(String::from(expected_region)));
        assert_eq!(cfg.region, expected_region);
        assert_eq!(cfg.account_id, "");
        assert!(!cfg.region_auto_detected);
    }

    /// Without a region the placeholder falls back to an empty string.
    #[test]
    fn test_dummy_config_without_region() {
        let cfg = AwsConfig::dummy(None);
        assert_eq!(cfg.region, "");
        assert_eq!(cfg.account_id, "");
        assert!(!cfg.region_auto_detected);
    }

    /// With no static credentials and a profile that does not exist, building
    /// the real configuration must fail with an authentication-style error.
    ///
    /// The environment changes are process-wide and are not restored afterwards.
    #[tokio::test]
    async fn test_new_fails_without_credentials() {
        for var in [
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_SESSION_TOKEN",
        ] {
            std::env::remove_var(var);
        }
        std::env::set_var("AWS_PROFILE", "nonexistent-profile-test");

        let outcome = AwsConfig::new(Some("us-east-1".to_string())).await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        let looks_like_auth_failure = ["credentials", "profile", "dispatch"]
            .iter()
            .any(|needle| message.contains(*needle));
        assert!(
            looks_like_auth_failure,
            "Expected auth error, got: {}",
            message
        );
    }

    /// A caller-supplied timeout is accepted. The test only checks that an
    /// error is returned, not which kind.
    #[tokio::test]
    async fn test_timeout_is_configurable() {
        std::env::remove_var("AWS_ACCESS_KEY_ID");
        std::env::remove_var("AWS_SECRET_ACCESS_KEY");
        std::env::set_var("AWS_PROFILE", "nonexistent-profile-test");

        let short_timeout = std::time::Duration::from_millis(100);
        let outcome =
            AwsConfig::new_with_timeout(Some("us-east-1".to_string()), short_timeout).await;

        assert!(outcome.is_err());
    }

    /// Every field of the placeholder has its documented default.
    #[test]
    fn test_dummy_config_preserves_values() {
        let expected_region = "eu-west-1";
        let cfg = AwsConfig::dummy(Some(String::from(expected_region)));
        assert_eq!(cfg.region, expected_region);
        assert_eq!(cfg.account_id, "");
        assert_eq!(cfg.role_arn, "");
        assert!(!cfg.region_auto_detected);
    }
}