1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4use crate::error::{AvError, Result};
5use crate::session::{Session, SessionStatus};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ApiEvent {
10 pub timestamp: DateTime<Utc>,
11 pub event_name: String,
12 pub service: String,
13 pub source_ip: Option<String>,
14 pub user_agent: Option<String>,
15 pub error_code: Option<String>,
16 pub error_message: Option<String>,
17 pub read_only: bool,
18}
19
20#[derive(Debug, Default)]
22pub struct WatchState {
23 pub events: Vec<ApiEvent>,
24 pub last_poll: Option<DateTime<Utc>>,
25 pub total_api_calls: usize,
26 pub error_count: usize,
27 pub services_used: Vec<String>,
28}
29
30impl WatchState {
31 pub fn merge(&mut self, new_events: Vec<ApiEvent>) {
33 for event in new_events {
34 let is_dup = self.events.iter().any(|e| {
35 e.timestamp == event.timestamp && e.event_name == event.event_name
36 });
37 if !is_dup {
38 if event.error_code.is_some() {
39 self.error_count += 1;
40 }
41 let svc = event.service.clone();
42 self.events.push(event);
43 self.total_api_calls += 1;
44 if !self.services_used.contains(&svc) {
45 self.services_used.push(svc);
46 }
47 }
48 }
49 self.last_poll = Some(Utc::now());
50 }
51}
52
53pub async fn poll_cloudtrail(
56 access_key_id: &str,
57 start_time: DateTime<Utc>,
58 region: Option<&str>,
59) -> Result<Vec<ApiEvent>> {
60 let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
61 if let Some(region) = region {
62 loader = loader.region(aws_config::Region::new(region.to_string()));
63 }
64 let config = loader.load().await;
65 let client = aws_sdk_cloudtrail::Client::new(&config);
66
67 let start = aws_sdk_cloudtrail::primitives::DateTime::from_secs(start_time.timestamp());
68
69 let result = client
70 .lookup_events()
71 .lookup_attributes(
72 aws_sdk_cloudtrail::types::LookupAttribute::builder()
73 .attribute_key(aws_sdk_cloudtrail::types::LookupAttributeKey::AccessKeyId)
74 .attribute_value(access_key_id)
75 .build()
76 .map_err(|e| AvError::Sts(format!("CloudTrail attribute build error: {}", e)))?,
77 )
78 .start_time(start)
79 .max_results(50)
80 .send()
81 .await
82 .map_err(|e| AvError::Sts(format!("CloudTrail LookupEvents error: {}", e)))?;
83
84 let events = result.events().iter().filter_map(|e| {
85 let event_name = e.event_name()?.to_string();
86
87 let service = e.event_source()
89 .map(|s| s.split('.').next().unwrap_or(s).to_string())
90 .unwrap_or_else(|| "unknown".to_string());
91
92 let read_only = e.read_only()
93 .and_then(|r| r.parse::<bool>().ok())
94 .unwrap_or(true);
95
96 let (error_code, error_message) = e.cloud_trail_event()
98 .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
99 .map(|v| {
100 let code = v.get("errorCode").and_then(|c| c.as_str()).map(String::from);
101 let msg = v.get("errorMessage").and_then(|m| m.as_str()).map(String::from);
102 (code, msg)
103 })
104 .unwrap_or((None, None));
105
106 let timestamp = e.event_time().map(|t| {
107 DateTime::from_timestamp(t.secs(), t.subsec_nanos())
108 .unwrap_or_else(Utc::now)
109 }).unwrap_or_else(Utc::now);
110
111 let source_ip = e.cloud_trail_event()
112 .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
113 .and_then(|v| v.get("sourceIPAddress").and_then(|s| s.as_str()).map(String::from));
114
115 let user_agent = e.cloud_trail_event()
116 .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
117 .and_then(|v| v.get("userAgent").and_then(|s| s.as_str()).map(String::from));
118
119 Some(ApiEvent {
120 timestamp,
121 event_name,
122 service,
123 source_ip,
124 user_agent,
125 error_code,
126 error_message,
127 read_only,
128 })
129 }).collect();
130
131 Ok(events)
132}
133
134pub fn is_watchable(session: &Session) -> bool {
136 session.status == SessionStatus::Active && !session.is_expired()
137}
138
139pub fn format_event(event: &ApiEvent) -> String {
141 let ts = event.timestamp.format("%H:%M:%S");
142 let rw = if event.read_only { "R" } else { "W" };
143 let status = if let Some(ref code) = event.error_code {
144 format!(" [ERR: {}]", code)
145 } else {
146 String::new()
147 };
148 format!(
149 "[{}] {} {:<12} {}{}",
150 ts, rw, event.service, event.event_name, status
151 )
152}
153
154pub fn format_summary(state: &WatchState, session: &Session) -> String {
156 let mut out = String::new();
157 out.push_str(&format!(
158 "Session {} | {} API calls | {} errors | {} services | {}s remaining\n",
159 &session.id[..8],
160 state.total_api_calls,
161 state.error_count,
162 state.services_used.len(),
163 session.remaining_seconds()
164 ));
165 if !state.services_used.is_empty() {
166 out.push_str(&format!("Services: {}\n", state.services_used.join(", ")));
167 }
168 out
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a read-only event stamped with the current time, a fixed source
    /// IP and user agent, and an optional error code.
    fn make_event(name: &str, service: &str, error: Option<&str>) -> ApiEvent {
        ApiEvent {
            timestamp: Utc::now(),
            event_name: name.to_owned(),
            service: service.to_owned(),
            source_ip: Some(String::from("10.0.0.1")),
            user_agent: Some(String::from("aws-cli/2.0")),
            error_code: error.map(str::to_owned),
            error_message: None,
            read_only: true,
        }
    }

    /// Two distinct calls to one service: both counted, service listed once.
    #[test]
    fn test_watch_state_merge() {
        let mut state = WatchState::default();
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("ListBuckets", "s3", None),
        ]);

        assert_eq!(state.total_api_calls, 2);
        assert_eq!(state.error_count, 0);
        assert_eq!(state.services_used, vec!["s3"]);
    }

    /// Merging the very same event twice must count it once.
    #[test]
    fn test_watch_state_dedup() {
        let event = ApiEvent {
            timestamp: Utc::now(),
            event_name: String::from("GetObject"),
            service: String::from("s3"),
            source_ip: None,
            user_agent: None,
            error_code: None,
            error_message: None,
            read_only: true,
        };

        let mut state = WatchState::default();
        state.merge(vec![event.clone()]);
        state.merge(vec![event]);

        assert_eq!(state.total_api_calls, 1);
    }

    /// Only events carrying an error code contribute to `error_count`.
    #[test]
    fn test_watch_state_error_count() {
        let mut state = WatchState::default();
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("PutObject", "s3", Some("AccessDenied")),
        ]);

        assert_eq!(state.error_count, 1);
    }

    /// Services are recorded once each, regardless of how often they recur.
    #[test]
    fn test_watch_state_multi_service() {
        let mut state = WatchState::default();
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
            make_event("ListBuckets", "s3", None),
        ]);

        assert_eq!(state.services_used.len(), 2);
        for expected in ["s3", "lambda"] {
            assert!(state.services_used.contains(&expected.to_string()));
        }
    }

    /// A successful read-only event shows its service, name and the `R` marker.
    #[test]
    fn test_format_event() {
        let line = format_event(&make_event("GetObject", "s3", None));

        for needle in ["s3", "GetObject", "R"] {
            assert!(line.contains(needle));
        }
    }

    /// A failed event shows its error code.
    #[test]
    fn test_format_event_with_error() {
        let line = format_event(&make_event("PutObject", "s3", Some("AccessDenied")));

        assert!(line.contains("AccessDenied"));
    }

    /// The summary reports the call count and the distinct-service count.
    #[test]
    fn test_format_summary() {
        let session = crate::session::Session::new(
            std::time::Duration::from_secs(900),
            None,
            crate::policy::ScopedPolicy::from_allow_str("s3:GetObject").unwrap(),
            "arn:aws:iam::123456789012:role/Test".to_string(),
            vec!["test".to_string()],
        );

        let mut state = WatchState::default();
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
        ]);

        let summary = format_summary(&state, &session);
        assert!(summary.contains("2 API calls"));
        assert!(summary.contains("2 services"));
    }
}