1use std::collections::BTreeSet;
2use std::time::Duration;
3
4use crate::error::{AvError, Result};
5
6pub fn event_source_to_service(source: &str) -> Option<&str> {
9 let s = source
11 .strip_suffix(".amazonaws.com.cn")
12 .or_else(|| source.strip_suffix(".amazonaws.com"))?;
13 Some(match s {
14 "s3" => "s3",
15 "lambda" => "lambda",
16 "dynamodb" => "dynamodb",
17 "ec2" => "ec2",
18 "iam" => "iam",
19 "sts" => "sts",
20 "sqs" => "sqs",
21 "sns" => "sns",
22 "logs" => "logs",
23 "ssm" => "ssm",
24 "cloudformation" => "cloudformation",
25 "cloudwatch" => "cloudwatch",
26 "monitoring" => "cloudwatch",
27 "kms" => "kms",
28 "secretsmanager" => "secretsmanager",
29 "elasticloadbalancing" => "elasticloadbalancing",
30 "autoscaling" => "autoscaling",
31 "ecr" => "ecr",
32 "ecs" => "ecs",
33 "eks" => "eks",
34 "rds" => "rds",
35 "route53" => "route53",
36 "cloudfront" => "cloudfront",
37 "apigateway" => "apigateway",
38 "cognito-idp" | "cognito-identity" => "cognito-idp",
39 "events" => "events",
40 "states" => "states",
41 "kinesis" => "kinesis",
42 "firehose" => "firehose",
43 "athena" => "athena",
44 "glue" => "glue",
45 "redshift" => "redshift",
46 "elasticache" => "elasticache",
47 other => {
48 tracing::warn!(
49 event_source = %source,
50 prefix = %other,
51 "Unknown CloudTrail eventSource — using raw prefix as IAM service name. \
52 This may not be a valid IAM service prefix and could produce an unusable policy."
53 );
54 other
55 }
56 })
57}
58
59fn cloudtrail_to_iam_action(service: &str, event_name: &str) -> String {
62 let iam_action = match (service, event_name) {
63 ("s3", "ListBuckets") => "s3:ListAllMyBuckets",
65 ("s3", "GetBucketAcl") => "s3:GetBucketAcl",
66 ("s3", "HeadBucket") => "s3:ListBucket",
67 ("s3", "ListObjects") => "s3:ListBucket",
68 ("s3", "ListObjectsV2") => "s3:ListBucket",
69 ("s3", "HeadObject") => "s3:GetObject",
70 ("ec2", "DescribeInstanceStatus") => "ec2:DescribeInstanceStatus",
72 ("sts", "GetCallerIdentity") => "sts:GetCallerIdentity",
74 ("apigateway", "CreateRestApi") => "apigateway:POST",
76 ("apigateway", "GetRestApi") => "apigateway:GET",
77 ("apigateway", "GetRestApis") => "apigateway:GET",
78 ("apigateway", "DeleteRestApi") => "apigateway:DELETE",
79 ("apigateway", "UpdateRestApi") => "apigateway:PATCH",
80 ("apigateway", "CreateDeployment") => "apigateway:POST",
81 ("apigateway", "CreateStage") => "apigateway:POST",
82 ("apigateway", "UpdateStage") => "apigateway:PATCH",
83 ("apigateway", "DeleteStage") => "apigateway:DELETE",
84 ("apigateway", "GetStages") => "apigateway:GET",
85 ("apigateway", "GetResources") => "apigateway:GET",
86 ("apigateway", "CreateResource") => "apigateway:POST",
87 ("ecs", "DescribeClusters") => "ecs:DescribeClusters",
89 ("ecs", "DescribeServices") => "ecs:DescribeServices",
90 ("ecs", "DescribeTaskDefinition") => "ecs:DescribeTaskDefinition",
91 ("ecs", "DescribeTasks") => "ecs:DescribeTasks",
92 ("logs", "CreateLogGroup") => "logs:CreateLogGroup",
94 ("logs", "FilterLogEvents") => "logs:FilterLogEvents",
95 ("elasticloadbalancing", "CreateLoadBalancer") => "elasticloadbalancing:CreateLoadBalancer",
97 ("elasticloadbalancing", "DescribeLoadBalancers") => {
98 "elasticloadbalancing:DescribeLoadBalancers"
99 }
100 ("elasticloadbalancing", "CreateTargetGroup") => "elasticloadbalancing:CreateTargetGroup",
101 ("elasticloadbalancing", "DescribeTargetGroups") => {
102 "elasticloadbalancing:DescribeTargetGroups"
103 }
104 ("states", "CreateStateMachine") => "states:CreateStateMachine",
106 ("states", "StartExecution") => "states:StartExecution",
107 ("states", "DescribeExecution") => "states:DescribeExecution",
108 _ => return default_action_for_service(service, event_name),
110 };
111 iam_action.to_string()
112}
113
/// Fallback mapping from a CloudTrail event to an IAM action, used when no
/// explicit mapping exists.
///
/// For most services the result is `"{service}:{event_name}"`. API Gateway
/// is special: its IAM actions are HTTP verbs (`apigateway:GET`,
/// `apigateway:POST`, `apigateway:PUT`, `apigateway:PATCH`,
/// `apigateway:DELETE`), so the verb is inferred from the event-name prefix.
/// Event names with no recognised prefix fall back to `apigateway:*`.
fn default_action_for_service(service: &str, event_name: &str) -> String {
    match service {
        "apigateway" => {
            let verb = if event_name.starts_with("Get") || event_name.starts_with("Describe") {
                "GET"
            } else if event_name.starts_with("Create") {
                "POST"
            } else if event_name.starts_with("Put") {
                // Put* operations (PutMethod, PutIntegration, ...) are HTTP
                // PUT requests and are authorized by `apigateway:PUT`.
                "PUT"
            } else if event_name.starts_with("Update") {
                "PATCH"
            } else if event_name.starts_with("Delete") {
                "DELETE"
            } else {
                // Unknown operation shape: grant every verb rather than
                // guessing a single one.
                "*"
            };
            format!("apigateway:{}", verb)
        }
        _ => format!("{}:{}", service, event_name),
    }
}
146
/// The set of IAM actions collected for a learned policy.
#[derive(Debug, Clone)]
pub struct LearnedPolicy {
    /// IAM action strings; the `BTreeSet` keeps them unique and sorted.
    pub actions: BTreeSet<String>,
}

impl LearnedPolicy {
    /// Returns the actions as a single comma-separated string, in sorted
    /// order. An empty policy yields an empty string.
    pub fn to_allow_str(&self) -> String {
        self.actions
            .iter()
            .map(String::as_str)
            .collect::<Vec<_>>()
            .join(",")
    }

    /// Renders a `[profiles.<name>]` TOML table carrying this policy.
    ///
    /// `name` is emitted as a bare key when it consists solely of ASCII
    /// letters, digits, `_` and `-`; otherwise it is emitted as a quoted key
    /// so that names containing dots, spaces, or quotes do not create nested
    /// tables or invalid TOML. The `allow` value is escaped as a TOML basic
    /// string.
    pub fn to_profile_toml(&self, name: &str) -> String {
        format!(
            "[profiles.{}]\nallow = \"{}\"\ndescription = \"Learned from running command\"\n",
            Self::toml_key(name),
            Self::escape_toml_basic(&self.to_allow_str())
        )
    }

    /// Formats `name` as a TOML key: bare when allowed, quoted otherwise.
    fn toml_key(name: &str) -> String {
        let is_bare = !name.is_empty()
            && name
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-');
        if is_bare {
            name.to_string()
        } else {
            format!("\"{}\"", Self::escape_toml_basic(name))
        }
    }

    /// Escapes `raw` for use inside a double-quoted TOML basic string.
    fn escape_toml_basic(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for c in raw.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if c.is_control() => out.push_str(&format!("\\u{:04X}", c as u32)),
                c => out.push(c),
            }
        }
        out
    }
}
169
170pub async fn lookup_events(
177 access_key_id: &str,
178 start_time: chrono::DateTime<chrono::Utc>,
179 end_time: chrono::DateTime<chrono::Utc>,
180 region: Option<&str>,
181) -> Result<LearnedPolicy> {
182 let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
183 if let Some(r) = region {
184 loader = loader.region(aws_config::Region::new(r.to_string()));
185 }
186 let config = loader.load().await;
187 let client = aws_sdk_cloudtrail::Client::new(&config);
188
189 tracing::warn!(
190 region = region.unwrap_or("(default)"),
191 "CloudTrail learn queries a single region — API calls in other regions will be missed"
192 );
193
194 let mut actions = BTreeSet::new();
195 let mut next_token: Option<String> = None;
196
197 loop {
198 let mut req = client
199 .lookup_events()
200 .start_time(aws_sdk_cloudtrail::primitives::DateTime::from_secs(
201 start_time.timestamp(),
202 ))
203 .end_time(aws_sdk_cloudtrail::primitives::DateTime::from_secs(
204 end_time.timestamp(),
205 ))
206 .lookup_attributes(
207 aws_sdk_cloudtrail::types::LookupAttribute::builder()
208 .attribute_key(aws_sdk_cloudtrail::types::LookupAttributeKey::AccessKeyId)
209 .attribute_value(access_key_id)
210 .build()
211 .map_err(|e| AvError::Sts(format!("CloudTrail lookup build error: {}", e)))?,
212 )
213 .max_results(50);
214
215 if let Some(ref token) = next_token {
216 req = req.next_token(token);
217 }
218
219 let resp = req
220 .send()
221 .await
222 .map_err(|e| AvError::Sts(format!("CloudTrail lookup error: {}", e)))?;
223
224 for event in resp.events() {
225 let event_name = event.event_name().unwrap_or_default();
226 let event_source = event.event_source().unwrap_or_default();
227
228 if let Some(service) = event_source_to_service(event_source) {
229 if service == "sts" && event_name == "AssumeRole" {
231 continue;
232 }
233 actions.insert(cloudtrail_to_iam_action(service, event_name));
234 }
235 }
236
237 next_token = resp.next_token().map(|s| s.to_string());
238 if next_token.is_none() {
239 break;
240 }
241 }
242
243 Ok(LearnedPolicy { actions })
244}
245
246pub async fn poll_cloudtrail(
261 access_key_id: &str,
262 start_time: chrono::DateTime<chrono::Utc>,
263 end_time: chrono::DateTime<chrono::Utc>,
264 region: Option<&str>,
265 timeout: Duration,
266 poll_interval: Duration,
267) -> Result<LearnedPolicy> {
268 let deadline = std::time::Instant::now() + timeout;
269
270 loop {
271 let result = lookup_events(access_key_id, start_time, end_time, region).await?;
272
273 if !result.actions.is_empty() {
274 return Ok(result);
275 }
276
277 if std::time::Instant::now() + poll_interval > deadline {
278 return Err(AvError::Sts(
279 "Timed out waiting for CloudTrail events. CloudTrail can take 5-15 minutes to propagate. \
280 Try again later — CloudTrail events may not be available yet."
281 .to_string(),
282 ));
283 }
284
285 tokio::time::sleep(poll_interval).await;
286 }
287}