Skip to main content

tryaudex_core/
learn.rs

1use std::collections::BTreeSet;
2use std::time::Duration;
3
4use crate::error::{AvError, Result};
5
/// Maps a CloudTrail `eventSource` (e.g. `"s3.amazonaws.com"`) to the IAM
/// service prefix used in policy actions (e.g. `"s3"`).
///
/// Returns `None` when `source` does not end in `.amazonaws.com` or when the
/// service label in front of that suffix is empty. For almost every service the
/// label already *is* the IAM prefix and is returned unchanged; only genuine
/// exceptions are listed below.
fn event_source_to_service(source: &str) -> Option<&str> {
    let label = source
        .strip_suffix(".amazonaws.com")
        .filter(|label| !label.is_empty())?;
    Some(match label {
        // CloudWatch metrics/alarms are served from `monitoring.amazonaws.com`
        // but authorized under the `cloudwatch:` prefix.
        "monitoring" => "cloudwatch",
        // Everything else (including `cognito-idp` for user pools and
        // `cognito-identity` for identity pools, which are distinct IAM
        // services) passes through unchanged.
        other => other,
    })
}
46
/// Strips the API-version suffix that CloudTrail appends to some event names
/// (e.g. `"GetFunction20150331v2"` -> `"GetFunction"`).
///
/// The suffix is an 8-digit date optionally followed by `v<digits>`. Names
/// without such a suffix, or that would become empty, are returned unchanged.
fn strip_api_version_suffix(event_name: &str) -> &str {
    // Peel an optional trailing revision marker such as "v2".
    let without_revision = match event_name.rfind('v') {
        Some(idx)
            if idx + 1 < event_name.len()
                && event_name[idx + 1..].bytes().all(|b| b.is_ascii_digit()) =>
        {
            &event_name[..idx]
        }
        _ => event_name,
    };
    // A real suffix must then end in an 8-digit date, with a name before it.
    let len = without_revision.len();
    if len > 8 && without_revision.as_bytes()[len - 8..].iter().all(u8::is_ascii_digit) {
        &without_revision[..len - 8]
    } else {
        event_name
    }
}

/// Maps a CloudTrail event name to the IAM action that authorizes it.
///
/// CloudTrail usually records the API operation name, which matches the IAM
/// action, so the default result is `"{service}:{event_name}"`. Known renames
/// are listed explicitly. Lambda event names have their API-version suffix
/// removed first, since IAM actions carry no version.
fn cloudtrail_to_iam_action(service: &str, event_name: &str) -> String {
    let event_name = if service == "lambda" {
        strip_api_version_suffix(event_name)
    } else {
        event_name
    };

    let renamed = match (service, event_name) {
        // S3 operations authorized by a differently named action.
        ("s3", "ListBuckets") => Some("s3:ListAllMyBuckets"),
        ("s3", "HeadBucket" | "ListObjects" | "ListObjectsV2") => Some("s3:ListBucket"),
        ("s3", "HeadObject") => Some("s3:GetObject"),
        ("s3", "ListObjectVersions") => Some("s3:ListBucketVersions"),
        ("s3", "CreateMultipartUpload" | "UploadPart" | "CompleteMultipartUpload") => {
            Some("s3:PutObject")
        }
        ("s3", "ListParts") => Some("s3:ListMultipartUploadParts"),
        ("s3", "ListMultipartUploads") => Some("s3:ListBucketMultipartUploads"),
        ("s3", "DeleteObjects") => Some("s3:DeleteObject"),
        ("s3", "GetBucketEncryption") => Some("s3:GetEncryptionConfiguration"),
        ("s3", "PutBucketEncryption" | "DeleteBucketEncryption") => {
            Some("s3:PutEncryptionConfiguration")
        }
        // The Lambda `Invoke` API is authorized by `lambda:InvokeFunction`.
        ("lambda", "Invoke") => Some("lambda:InvokeFunction"),
        _ => None,
    };

    match renamed {
        Some(action) => action.to_string(),
        None => format!("{}:{}", service, event_name),
    }
}
67
/// Result of learning: observed API calls mapped to IAM actions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LearnedPolicy {
    /// Unique IAM actions observed, kept in sorted order
    /// (e.g. "s3:ListAllMyBuckets").
    pub actions: BTreeSet<String>,
}

impl LearnedPolicy {
    /// Formats the actions as a comma-separated `--allow` string, in sorted order.
    pub fn to_allow_str(&self) -> String {
        self.actions
            .iter()
            .map(String::as_str)
            .collect::<Vec<_>>()
            .join(",")
    }

    /// Formats a `[profiles.<name>]` TOML block for the config file.
    ///
    /// `name` is emitted as a bare key when it is non-empty and consists only of
    /// ASCII letters, digits, `_` and `-`. Otherwise it is written as a quoted,
    /// escaped key, so names containing dots, spaces or quotes still produce
    /// valid TOML addressing a single profile.
    pub fn to_profile_toml(&self, name: &str) -> String {
        format!(
            "[profiles.{}]\nallow = {}\ndescription = \"Learned from running command\"\n",
            toml_key(name),
            toml_basic_string(&self.to_allow_str())
        )
    }
}

/// Renders `key` as a TOML key: bare when the spec allows it, quoted otherwise.
fn toml_key(key: &str) -> String {
    let is_bare = !key.is_empty()
        && key
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-');
    if is_bare {
        key.to_string()
    } else {
        toml_basic_string(key)
    }
}

/// Renders `value` as a double-quoted TOML basic string with required escapes.
fn toml_basic_string(value: &str) -> String {
    let mut out = String::with_capacity(value.len() + 2);
    out.push('"');
    for c in value.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Other control characters are not allowed raw in basic strings.
            c if c.is_control() => out.push_str(&format!("\\u{:04X}", c as u32)),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}
90
91/// Query CloudTrail for API calls made by a specific access key within a time window.
92pub async fn lookup_events(
93    access_key_id: &str,
94    start_time: chrono::DateTime<chrono::Utc>,
95    end_time: chrono::DateTime<chrono::Utc>,
96    region: Option<&str>,
97) -> Result<LearnedPolicy> {
98    let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
99    if let Some(r) = region {
100        loader = loader.region(aws_config::Region::new(r.to_string()));
101    }
102    let config = loader.load().await;
103    let client = aws_sdk_cloudtrail::Client::new(&config);
104
105    let mut actions = BTreeSet::new();
106    let mut next_token: Option<String> = None;
107
108    loop {
109        let mut req = client
110            .lookup_events()
111            .start_time(aws_sdk_cloudtrail::primitives::DateTime::from_secs(
112                start_time.timestamp(),
113            ))
114            .end_time(aws_sdk_cloudtrail::primitives::DateTime::from_secs(
115                end_time.timestamp(),
116            ))
117            .lookup_attributes(
118                aws_sdk_cloudtrail::types::LookupAttribute::builder()
119                    .attribute_key(aws_sdk_cloudtrail::types::LookupAttributeKey::AccessKeyId)
120                    .attribute_value(access_key_id)
121                    .build()
122                    .map_err(|e| AvError::Sts(format!("CloudTrail lookup build error: {}", e)))?,
123            )
124            .max_results(50);
125
126        if let Some(ref token) = next_token {
127            req = req.next_token(token);
128        }
129
130        let resp = req
131            .send()
132            .await
133            .map_err(|e| AvError::Sts(format!("CloudTrail lookup error: {}", e)))?;
134
135        for event in resp.events() {
136            let event_name = event.event_name().unwrap_or_default();
137            let event_source = event.event_source().unwrap_or_default();
138
139            if let Some(service) = event_source_to_service(event_source) {
140                // Skip the AssumeRole call that Audex itself makes
141                if service == "sts" && event_name == "AssumeRole" {
142                    continue;
143                }
144                actions.insert(cloudtrail_to_iam_action(service, event_name));
145            }
146        }
147
148        next_token = resp.next_token().map(|s| s.to_string());
149        if next_token.is_none() {
150            break;
151        }
152    }
153
154    Ok(LearnedPolicy { actions })
155}
156
157/// Poll CloudTrail until events appear or timeout is reached.
158/// CloudTrail typically has a 5-15 minute delay.
159pub async fn poll_cloudtrail(
160    access_key_id: &str,
161    start_time: chrono::DateTime<chrono::Utc>,
162    end_time: chrono::DateTime<chrono::Utc>,
163    region: Option<&str>,
164    timeout: Duration,
165    poll_interval: Duration,
166) -> Result<LearnedPolicy> {
167    let deadline = std::time::Instant::now() + timeout;
168
169    loop {
170        let result = lookup_events(access_key_id, start_time, end_time, region).await?;
171
172        if !result.actions.is_empty() {
173            return Ok(result);
174        }
175
176        if std::time::Instant::now() + poll_interval > deadline {
177            return Err(AvError::Sts(
178                "Timed out waiting for CloudTrail events. CloudTrail can take 5-15 minutes to propagate. \
179                 Try again later with: tryaudex learn --replay <session-id>"
180                    .to_string(),
181            ));
182        }
183
184        tokio::time::sleep(poll_interval).await;
185    }
186}