rusticity_core/
config.rs

1use anyhow::Result;
2
/// Resolved AWS settings for this process: the region SDK clients are built for and the
/// caller identity reported by STS `GetCallerIdentity`.
#[derive(Debug, Clone)]
pub struct AwsConfig {
    /// Region name used when constructing SDK clients.
    pub region: String,
    /// Account ID reported by STS for the active credentials; empty when unavailable.
    pub account_id: String,
    /// ARN of the calling principal as reported by STS; empty when unavailable.
    pub role_arn: String,
    /// `true` when `region` was picked by latency probing instead of being configured.
    pub region_auto_detected: bool,
}
10
11impl AwsConfig {
12    pub async fn new(region: Option<String>) -> Result<Self> {
13        Self::new_with_timeout(region, std::time::Duration::from_secs(10)).await
14    }
15
16    pub async fn new_with_timeout(
17        region: Option<String>,
18        timeout: std::time::Duration,
19    ) -> Result<Self> {
20        // Check for region early to avoid IMDS timeout
21        if region.is_none()
22            && std::env::var("AWS_REGION").is_err()
23            && std::env::var("AWS_DEFAULT_REGION").is_err()
24        {
25            return Err(anyhow::anyhow!("Missing Region"));
26        }
27
28        let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
29
30        // Use profile from AWS_PROFILE env var if set
31        let profile_to_use = std::env::var("AWS_PROFILE").ok();
32        if let Some(ref profile) = profile_to_use {
33            if !profile.is_empty() {
34                tracing::info!("Using AWS profile: {}", profile);
35                config_loader = config_loader.profile_name(profile);
36            }
37        }
38
39        if let Some(r) = &region {
40            config_loader = config_loader.region(aws_config::Region::new(r.clone()));
41        }
42
43        // Load config with timeout
44        let config = tokio::time::timeout(timeout, config_loader.load())
45            .await
46            .map_err(|_| anyhow::anyhow!("Timeout loading AWS config"))?;
47
48        // Double-check region is set
49        if config.region().is_none() {
50            return Err(anyhow::anyhow!("Missing Region"));
51        }
52
53        // Try to get identity with timeout
54        let (account_id, role_arn) =
55            match tokio::time::timeout(timeout, Self::try_get_identity(&config)).await {
56                Ok(Ok((acc, role))) => {
57                    tracing::info!("Loaded identity: account={}, role={}", acc, role);
58                    (acc, role)
59                }
60                Ok(Err(e)) => {
61                    tracing::error!("Failed to get identity: {}", e);
62                    return Err(e);
63                }
64                Err(_) => return Err(anyhow::anyhow!("Timeout getting AWS identity")),
65            };
66
67        let (region_str, auto_detected) = match config.region() {
68            Some(r) => (r.as_ref().to_string(), false),
69            None => {
70                let fastest = Self::find_fastest_region().await?;
71                (fastest, true)
72            }
73        };
74
75        Ok(Self {
76            region: region_str,
77            account_id,
78            role_arn,
79            region_auto_detected: auto_detected,
80        })
81    }
82
83    async fn try_get_identity(config: &aws_config::SdkConfig) -> Result<(String, String)> {
84        let sts_client = aws_sdk_sts::Client::new(config);
85        let identity = sts_client.get_caller_identity().send().await?;
86        let account_id = identity.account().unwrap_or("").to_string();
87        let role_arn = identity.arn().unwrap_or("").to_string();
88        Ok((account_id, role_arn))
89    }
90
91    async fn find_fastest_region() -> Result<String> {
92        use std::time::Instant;
93
94        let regions = [
95            "us-east-1",
96            "us-east-2",
97            "us-west-1",
98            "us-west-2",
99            "af-south-1",
100            "ap-east-1",
101            "ap-south-1",
102            "ap-south-2",
103            "ap-northeast-1",
104            "ap-northeast-2",
105            "ap-northeast-3",
106            "ap-southeast-1",
107            "ap-southeast-2",
108            "ap-southeast-3",
109            "ap-southeast-4",
110            "ca-central-1",
111            "ca-west-1",
112            "eu-central-1",
113            "eu-central-2",
114            "eu-west-1",
115            "eu-west-2",
116            "eu-west-3",
117            "eu-north-1",
118            "eu-south-1",
119            "eu-south-2",
120            "il-central-1",
121            "me-central-1",
122            "me-south-1",
123            "sa-east-1",
124        ];
125
126        let mut tasks = Vec::new();
127
128        for &region in &regions {
129            let region_name = region.to_string();
130            tasks.push(tokio::spawn(async move {
131                let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
132                    .region(aws_config::Region::new(region_name.clone()))
133                    .load()
134                    .await;
135                let s3 = aws_sdk_s3::Client::new(&config);
136                let start = Instant::now();
137                match tokio::time::timeout(
138                    std::time::Duration::from_secs(2),
139                    s3.list_buckets().send(),
140                )
141                .await
142                {
143                    Ok(Ok(_)) => Some((region_name, start.elapsed())),
144                    _ => Some((region_name, std::time::Duration::from_secs(9999))),
145                }
146            }));
147        }
148
149        let results = futures::future::join_all(tasks).await;
150        let mut latencies: Vec<(String, std::time::Duration)> = results
151            .into_iter()
152            .filter_map(|r| r.ok().flatten())
153            .collect();
154
155        latencies.sort_by_key(|(_, d)| *d);
156
157        latencies
158            .first()
159            .map(|(r, _)| r.clone())
160            .ok_or_else(|| anyhow::anyhow!("Could not determine fastest region"))
161    }
162
163    pub fn dummy(region: Option<String>) -> Self {
164        Self {
165            region: region.unwrap_or_default(),
166            account_id: "".to_string(),
167            role_arn: "".to_string(),
168            region_auto_detected: false,
169        }
170    }
171
172    pub async fn get_account_for_profile(profile: &str, region: &str) -> Result<String> {
173        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
174            .profile_name(profile)
175            .region(aws_config::Region::new(region.to_string()))
176            .load()
177            .await;
178
179        let sts_client = aws_sdk_sts::Client::new(&config);
180        let identity = sts_client.get_caller_identity().send().await?;
181        Ok(identity.account().unwrap_or("").to_string())
182    }
183
184    pub async fn s3_client(&self) -> aws_sdk_s3::Client {
185        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
186            .region(aws_config::Region::new(self.region.clone()))
187            .load()
188            .await;
189        aws_sdk_s3::Client::new(&config)
190    }
191
192    pub async fn s3_client_with_region(&self, region: &str) -> aws_sdk_s3::Client {
193        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
194            .region(aws_config::Region::new(region.to_string()))
195            .load()
196            .await;
197        aws_sdk_s3::Client::new(&config)
198    }
199
200    pub async fn cloudformation_client(&self) -> aws_sdk_cloudformation::Client {
201        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
202            .region(aws_config::Region::new(self.region.clone()))
203            .load()
204            .await;
205        aws_sdk_cloudformation::Client::new(&config)
206    }
207
208    pub async fn lambda_client(&self) -> aws_sdk_lambda::Client {
209        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
210            .region(aws_config::Region::new(self.region.clone()))
211            .load()
212            .await;
213        aws_sdk_lambda::Client::new(&config)
214    }
215
216    pub async fn iam_client(&self) -> aws_sdk_iam::Client {
217        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
218            .region(aws_config::Region::new(self.region.clone()))
219            .load()
220            .await;
221        aws_sdk_iam::Client::new(&config)
222    }
223
224    pub async fn ecr_client(&self) -> aws_sdk_ecr::Client {
225        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
226            .region(aws_config::Region::new(self.region.clone()))
227            .load()
228            .await;
229        aws_sdk_ecr::Client::new(&config)
230    }
231
232    pub async fn ecr_public_client(&self) -> aws_sdk_ecrpublic::Client {
233        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
234            .region(aws_config::Region::new(self.region.clone()))
235            .load()
236            .await;
237        aws_sdk_ecrpublic::Client::new(&config)
238    }
239
240    pub async fn cloudwatch_client(&self) -> aws_sdk_cloudwatch::Client {
241        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
242            .region(aws_config::Region::new(self.region.clone()))
243            .load()
244            .await;
245        aws_sdk_cloudwatch::Client::new(&config)
246    }
247
248    pub async fn sqs_client(&self) -> aws_sdk_sqs::Client {
249        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
250            .region(aws_config::Region::new(self.region.clone()))
251            .load()
252            .await;
253        aws_sdk_sqs::Client::new(&config)
254    }
255
256    pub async fn cloudwatch_logs_client(&self) -> aws_sdk_cloudwatchlogs::Client {
257        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
258            .region(aws_config::Region::new(self.region.clone()))
259            .load()
260            .await;
261        aws_sdk_cloudwatchlogs::Client::new(&config)
262    }
263
264    pub async fn pipes_client(&self) -> aws_sdk_pipes::Client {
265        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
266            .region(aws_config::Region::new(self.region.clone()))
267            .load()
268            .await;
269        aws_sdk_pipes::Client::new(&config)
270    }
271
272    pub async fn ec2_client(&self) -> aws_sdk_ec2::Client {
273        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
274            .region(aws_config::Region::new(self.region.clone()))
275            .load()
276            .await;
277        aws_sdk_ec2::Client::new(&config)
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts the fields shared by every `AwsConfig::dummy` value, plus the expected region.
    fn assert_dummy(config: &AwsConfig, expected_region: &str) {
        assert_eq!(config.region, expected_region);
        assert_eq!(config.account_id, "");
        assert!(!config.region_auto_detected);
    }

    /// Removes the listed credential variables, then selects a profile that should not exist.
    fn use_missing_profile(cleared: &[&str]) {
        for &var in cleared {
            std::env::remove_var(var);
        }
        std::env::set_var("AWS_PROFILE", "nonexistent-profile-test");
    }

    #[test]
    fn test_dummy_config_with_region() {
        assert_dummy(&AwsConfig::dummy(Some("us-west-2".to_string())), "us-west-2");
    }

    #[test]
    fn test_dummy_config_without_region() {
        assert_dummy(&AwsConfig::dummy(None), "");
    }

    #[tokio::test]
    async fn test_new_fails_without_credentials() {
        use_missing_profile(&["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]);

        let outcome = AwsConfig::new(Some("us-east-1".to_string())).await;

        // A region is supplied, so any failure must come from credential resolution.
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        // Credential, profile, and dispatch errors all count as auth failures.
        let looks_like_auth_error = ["credentials", "profile", "dispatch"]
            .iter()
            .any(|&needle| message.contains(needle));
        assert!(looks_like_auth_error, "Expected auth error, got: {}", message);
    }

    #[tokio::test]
    async fn test_timeout_is_configurable() {
        use_missing_profile(&["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]);

        let budget = std::time::Duration::from_millis(100);
        let outcome = AwsConfig::new_with_timeout(Some("us-east-1".to_string()), budget).await;

        assert!(outcome.is_err());
    }

    #[test]
    fn test_dummy_config_preserves_values() {
        let config = AwsConfig::dummy(Some("eu-west-1".to_string()));
        assert_dummy(&config, "eu-west-1");
        assert_eq!(config.role_arn, "");
    }
}