// tryaudex_core/estimate.rs — cost estimation for sets of IAM actions.

use std::collections::{BTreeMap, HashMap};

use serde::{Deserialize, Serialize};

use crate::policy::ScopedPolicy;
6
7/// Cost estimate for a set of IAM actions.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct CostEstimate {
10    /// Per-service cost breakdown.
11    pub services: Vec<ServiceEstimate>,
12    /// Total estimated cost for the session.
13    pub total_min: f64,
14    pub total_max: f64,
15    /// Risk level: "low", "medium", "high".
16    pub risk_level: String,
17    /// Human-readable warnings.
18    pub warnings: Vec<String>,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ServiceEstimate {
23    pub service: String,
24    pub actions: Vec<String>,
25    pub min_cost: f64,
26    pub max_cost: f64,
27    pub notes: String,
28}
29
/// Known cost profiles for common AWS services, keyed by the lower-case IAM
/// service prefix (e.g. `"s3"`, `"dynamodb"`).
///
/// These are rough, static estimates based on typical usage patterns. Real
/// costs depend on usage volume, region, free tiers, and data transfer.
/// All `*_cost_per_1k` values are USD per 1,000 requests.
fn cost_profiles() -> HashMap<&'static str, ServiceCostProfile> {
    HashMap::from([
        ("s3", ServiceCostProfile {
            read_actions: &["GetObject", "ListBucket", "ListAllMyBuckets", "GetBucketLocation", "HeadObject"],
            write_actions: &["PutObject", "DeleteObject", "CreateBucket", "CopyObject"],
            read_cost_per_1k: 0.0004,   // $0.0004 per 1,000 GET requests
            write_cost_per_1k: 0.005,   // $0.005 per 1,000 PUT requests
            data_cost_per_gb: 0.023,    // $0.023/GB stored
            typical_requests: 100,
            notes: "S3 pricing: GET $0.0004/1K, PUT $0.005/1K, storage $0.023/GB/mo",
        }),
        ("lambda", ServiceCostProfile {
            read_actions: &["GetFunction", "ListFunctions", "GetFunctionConfiguration"],
            write_actions: &["InvokeFunction", "UpdateFunctionCode", "UpdateFunctionConfiguration", "CreateFunction"],
            read_cost_per_1k: 0.0,
            write_cost_per_1k: 0.0002,  // $0.20 per 1M invocations = $0.0002 per 1K
            // Lambda compute is billed per GB-second, so unlike the other
            // profiles this is not a per-GB storage/transfer price.
            data_cost_per_gb: 0.0000167, // $0.0000166667/GB-second
            typical_requests: 10,
            notes: "Lambda: $0.20/1M requests + $0.0000166667/GB-s compute",
        }),
        ("dynamodb", ServiceCostProfile {
            read_actions: &["GetItem", "Query", "Scan", "BatchGetItem", "DescribeTable", "ListTables"],
            write_actions: &["PutItem", "UpdateItem", "DeleteItem", "BatchWriteItem"],
            read_cost_per_1k: 0.00025,  // $0.25 per 1M RRU
            write_cost_per_1k: 0.00125, // $1.25 per 1M WRU
            data_cost_per_gb: 0.25,
            typical_requests: 100,
            notes: "DynamoDB on-demand: $0.25/1M RRU, $1.25/1M WRU",
        }),
        ("ec2", ServiceCostProfile {
            read_actions: &["DescribeInstances", "DescribeSecurityGroups", "DescribeSubnets", "DescribeVpcs", "DescribeImages"],
            write_actions: &["RunInstances", "TerminateInstances", "StartInstances", "StopInstances", "CreateSecurityGroup"],
            read_cost_per_1k: 0.0,
            write_cost_per_1k: 0.0,     // EC2 API calls are free; instances are what cost money
            data_cost_per_gb: 0.0,
            typical_requests: 10,
            notes: "EC2 API calls are free. Instance costs: $0.0116/hr (t3.micro) to $3.84/hr (p3.16xlarge)",
        }),
        ("iam", ServiceCostProfile {
            read_actions: &["GetRole", "GetPolicy", "GetPolicyVersion", "ListRolePolicies", "ListAttachedRolePolicies"],
            write_actions: &["CreateRole", "DeleteRole", "AttachRolePolicy", "DetachRolePolicy", "PutRolePolicy", "DeleteRolePolicy", "PassRole"],
            read_cost_per_1k: 0.0,
            write_cost_per_1k: 0.0,
            data_cost_per_gb: 0.0,
            typical_requests: 10,
            notes: "IAM API calls are free. WARNING: IAM changes affect account security",
        }),
        ("sqs", ServiceCostProfile {
            read_actions: &["ReceiveMessage", "GetQueueAttributes", "ListQueues"],
            write_actions: &["SendMessage", "DeleteMessage", "CreateQueue"],
            read_cost_per_1k: 0.0004,   // $0.40 per 1M requests
            write_cost_per_1k: 0.0004,
            data_cost_per_gb: 0.0,
            typical_requests: 100,
            notes: "SQS: $0.40/1M requests (first 1M free)",
        }),
        ("sns", ServiceCostProfile {
            read_actions: &["ListTopics", "GetTopicAttributes"],
            write_actions: &["Publish"],
            read_cost_per_1k: 0.0,
            write_cost_per_1k: 0.0005,  // $0.50 per 1M publishes
            data_cost_per_gb: 0.0,
            typical_requests: 10,
            notes: "SNS: $0.50/1M publishes",
        }),
        ("ecr", ServiceCostProfile {
            read_actions: &["GetAuthorizationToken", "BatchCheckLayerAvailability", "GetDownloadUrlForLayer", "BatchGetImage", "DescribeRepositories"],
            write_actions: &["PutImage", "InitiateLayerUpload", "UploadLayerPart", "CompleteLayerUpload", "CreateRepository"],
            read_cost_per_1k: 0.0,
            write_cost_per_1k: 0.0,
            data_cost_per_gb: 0.10,
            typical_requests: 10,
            notes: "ECR: $0.10/GB/month storage. Data transfer charges apply",
        }),
        ("logs", ServiceCostProfile {
            read_actions: &["GetLogEvents", "DescribeLogGroups", "DescribeLogStreams", "FilterLogEvents"],
            write_actions: &["PutLogEvents", "CreateLogGroup", "CreateLogStream"],
            read_cost_per_1k: 0.005,
            write_cost_per_1k: 0.0,
            data_cost_per_gb: 0.50,
            typical_requests: 50,
            notes: "CloudWatch Logs: $0.50/GB ingested, $0.005/1K queries",
        }),
        ("sts", ServiceCostProfile {
            read_actions: &["GetCallerIdentity"],
            write_actions: &["AssumeRole"],
            read_cost_per_1k: 0.0,
            write_cost_per_1k: 0.0,
            data_cost_per_gb: 0.0,
            typical_requests: 1,
            notes: "STS API calls are free",
        }),
        ("cloudformation", ServiceCostProfile {
            read_actions: &["DescribeStacks", "ListStacks", "GetTemplate"],
            write_actions: &["CreateStack", "UpdateStack", "DeleteStack"],
            read_cost_per_1k: 0.0,
            write_cost_per_1k: 0.0,
            data_cost_per_gb: 0.0,
            typical_requests: 10,
            notes: "CloudFormation: free for AWS resources, $0.0009/handler operation for third-party",
        }),
    ])
}

/// Static pricing profile for one AWS service.
struct ServiceCostProfile {
    /// Action names (matched as prefixes of granted actions) counted as reads.
    read_actions: &'static [&'static str],
    /// Action names (matched as prefixes of granted actions) counted as writes.
    write_actions: &'static [&'static str],
    /// USD per 1,000 read requests.
    read_cost_per_1k: f64,
    /// USD per 1,000 write requests.
    write_cost_per_1k: f64,
    /// USD per GB of data (storage or ingestion depending on the service;
    /// Lambda's value is per GB-second of compute).
    data_cost_per_gb: f64,
    /// Assumed number of requests contributed by each granted action.
    typical_requests: u32,
    /// Human-readable pricing summary shown next to the estimate.
    notes: &'static str,
}
146
147/// Estimate the potential cost of running a session with the given policy.
148pub fn estimate(policy: &ScopedPolicy, ttl_seconds: u64) -> CostEstimate {
149    let profiles = cost_profiles();
150    let mut services = Vec::new();
151    let mut warnings = Vec::new();
152    let mut total_min = 0.0;
153    let mut total_max = 0.0;
154
155    // Group actions by service
156    let mut by_service: HashMap<String, Vec<String>> = HashMap::new();
157    for action in &policy.actions {
158        by_service
159            .entry(action.service.clone())
160            .or_default()
161            .push(action.action.clone());
162    }
163
164    let ttl_hours = ttl_seconds as f64 / 3600.0;
165
166    for (service, actions) in &by_service {
167        if let Some(profile) = profiles.get(service.as_str()) {
168            let mut read_count = 0;
169            let mut write_count = 0;
170
171            for action in actions {
172                if action == "*" {
173                    // Wildcard — assume mix of read + write
174                    read_count += profile.typical_requests;
175                    write_count += profile.typical_requests;
176                    warnings.push(format!(
177                        "{}:* grants all actions — cost depends on actual usage",
178                        service
179                    ));
180                } else if profile
181                    .read_actions
182                    .iter()
183                    .any(|r| action.starts_with(r) || *r == action.as_str())
184                {
185                    read_count += profile.typical_requests;
186                } else if profile
187                    .write_actions
188                    .iter()
189                    .any(|w| action.starts_with(w) || *w == action.as_str())
190                {
191                    write_count += profile.typical_requests;
192                } else {
193                    // Unknown action, assume write
194                    write_count += profile.typical_requests / 2;
195                }
196            }
197
198            // Scale by TTL (longer sessions = more potential requests)
199            let scale = (ttl_hours * 0.5).max(1.0); // at least 1x
200            let scaled_reads = (read_count as f64 * scale) as u32;
201            let scaled_writes = (write_count as f64 * scale) as u32;
202
203            let read_cost = (scaled_reads as f64 / 1000.0) * profile.read_cost_per_1k;
204            let write_cost = (scaled_writes as f64 / 1000.0) * profile.write_cost_per_1k;
205            let min_cost = read_cost + write_cost;
206            // Max assumes 10x typical usage
207            let max_cost = min_cost * 10.0 + profile.data_cost_per_gb * 0.1;
208
209            total_min += min_cost;
210            total_max += max_cost;
211
212            services.push(ServiceEstimate {
213                service: service.clone(),
214                actions: actions.clone(),
215                min_cost,
216                max_cost,
217                notes: profile.notes.to_string(),
218            });
219
220            // High-cost warnings
221            if service == "ec2"
222                && actions
223                    .iter()
224                    .any(|a| a.contains("RunInstances") || a == "*")
225            {
226                warnings.push(
227                    "ec2:RunInstances can launch instances costing up to $3.84/hr".to_string(),
228                );
229            }
230        } else {
231            services.push(ServiceEstimate {
232                service: service.clone(),
233                actions: actions.clone(),
234                min_cost: 0.0,
235                max_cost: 0.0,
236                notes: "No cost estimate available for this service".to_string(),
237            });
238        }
239    }
240
241    // Determine risk level
242    let risk_level = if total_max > 10.0 || !warnings.is_empty() {
243        "high"
244    } else if total_max > 1.0 {
245        "medium"
246    } else {
247        "low"
248    }
249    .to_string();
250
251    // IAM-specific warnings
252    if by_service.contains_key("iam") {
253        warnings.push("IAM actions can modify account security — review carefully".to_string());
254    }
255
256    services.sort_by(|a, b| {
257        b.max_cost
258            .partial_cmp(&a.max_cost)
259            .unwrap_or(std::cmp::Ordering::Equal)
260    });
261
262    CostEstimate {
263        services,
264        total_min,
265        total_max,
266        risk_level,
267        warnings,
268    }
269}
270
271/// Format a cost estimate as human-readable text.
272pub fn format_text(est: &CostEstimate) -> String {
273    let mut out = String::new();
274
275    out.push_str(&format!(
276        "Estimated cost: ${:.4} — ${:.4}\n",
277        est.total_min, est.total_max
278    ));
279    out.push_str(&format!("Risk level: {}\n\n", est.risk_level));
280
281    for svc in &est.services {
282        out.push_str(&format!(
283            "  {} (${:.4} — ${:.4})\n",
284            svc.service, svc.min_cost, svc.max_cost
285        ));
286        out.push_str(&format!("    Actions: {}\n", svc.actions.join(", ")));
287        out.push_str(&format!("    {}\n", svc.notes));
288    }
289
290    if !est.warnings.is_empty() {
291        out.push_str("\nWarnings:\n");
292        for w in &est.warnings {
293            out.push_str(&format!("  ! {}\n", w));
294        }
295    }
296
297    out
298}
299
300/// Format a cost estimate as JSON.
301pub fn format_json(est: &CostEstimate) -> String {
302    serde_json::to_string_pretty(est).unwrap_or_else(|_| "{}".to_string())
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_s3_readonly() {
        let policy = ScopedPolicy::from_allow_str("s3:GetObject,s3:ListBucket").unwrap();
        let est = estimate(&policy, 900);
        assert_eq!(est.services.len(), 1);
        assert_eq!(est.services[0].service, "s3");
        // S3 read requests are billed, so the lower bound is strictly positive.
        assert!(est.total_min > 0.0);
        assert!(est.total_max >= est.total_min);
        assert_eq!(est.risk_level, "low");
    }

    #[test]
    fn test_estimate_ec2_high_risk() {
        let policy =
            ScopedPolicy::from_allow_str("ec2:RunInstances,ec2:DescribeInstances").unwrap();
        let est = estimate(&policy, 3600);
        assert_eq!(est.risk_level, "high");
        assert!(est.warnings.iter().any(|w| w.contains("RunInstances")));
    }

    #[test]
    fn test_estimate_iam_warning() {
        let policy = ScopedPolicy::from_allow_str("iam:CreateRole").unwrap();
        let est = estimate(&policy, 900);
        assert!(est.warnings.iter().any(|w| w.contains("IAM")));
    }

    #[test]
    fn test_estimate_wildcard() {
        let policy = ScopedPolicy::from_allow_str("s3:*").unwrap();
        let est = estimate(&policy, 900);
        assert!(est.warnings.iter().any(|w| w.contains("s3:*")));
        // Any cost warning escalates the risk level.
        assert_eq!(est.risk_level, "high");
    }

    #[test]
    fn test_estimate_multi_service() {
        let policy =
            ScopedPolicy::from_allow_str("s3:GetObject,lambda:InvokeFunction,dynamodb:Query")
                .unwrap();
        let est = estimate(&policy, 900);
        assert_eq!(est.services.len(), 3);
    }

    #[test]
    fn test_estimate_unknown_service() {
        let policy = ScopedPolicy::from_allow_str("xray:GetTraceSummaries").unwrap();
        let est = estimate(&policy, 900);
        assert_eq!(est.services.len(), 1);
        assert!(est.services[0].notes.contains("No cost estimate"));
        assert_eq!(est.services[0].min_cost, 0.0);
        assert_eq!(est.services[0].max_cost, 0.0);
    }

    #[test]
    fn test_format_text() {
        let policy = ScopedPolicy::from_allow_str("s3:GetObject").unwrap();
        let est = estimate(&policy, 900);
        let text = format_text(&est);
        assert!(text.contains("Estimated cost"));
        assert!(text.contains("s3"));
    }

    #[test]
    fn test_format_json() {
        let policy = ScopedPolicy::from_allow_str("s3:GetObject").unwrap();
        let est = estimate(&policy, 900);
        let json = format_json(&est);
        assert!(json.contains("total_min"));
        assert!(json.contains("risk_level"));
        // The JSON output must round-trip back into a CostEstimate.
        let parsed: CostEstimate = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.risk_level, est.risk_level);
        assert_eq!(parsed.services.len(), est.services.len());
    }

    #[test]
    fn test_longer_ttl_scales_cost() {
        let policy = ScopedPolicy::from_allow_str("s3:GetObject,s3:PutObject").unwrap();
        let short = estimate(&policy, 900); // 15 min
        let long = estimate(&policy, 14400); // 4 hours
        // A longer TTL must strictly increase the estimate, not merely keep it equal.
        assert!(long.total_min > short.total_min);
        assert!(long.total_max > short.total_max);
    }

    #[test]
    fn test_free_services() {
        let policy = ScopedPolicy::from_allow_str("sts:GetCallerIdentity").unwrap();
        let est = estimate(&policy, 900);
        assert_eq!(est.total_min, 0.0);
        assert_eq!(est.total_max, 0.0);
    }
}