1use std::collections::HashSet;
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5
6use crate::error::{AvError, Result};
7use crate::session::{Session, SessionStatus};
8
/// A single CloudTrail API call, flattened into the fields the watcher displays and aggregates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiEvent {
    /// When the call happened, in UTC (see `poll_cloudtrail` for how a missing time is handled).
    pub timestamp: DateTime<Utc>,
    /// API action name as reported by CloudTrail, e.g. `GetObject`.
    pub event_name: String,
    /// Short service name derived from the CloudTrail event source, or `"unknown"`.
    pub service: String,
    /// `sourceIPAddress` from the raw CloudTrail record, when present.
    pub source_ip: Option<String>,
    /// `userAgent` from the raw CloudTrail record, when present.
    pub user_agent: Option<String>,
    /// `errorCode` from the raw record; `Some` marks a failed call and is counted as an error.
    pub error_code: Option<String>,
    /// `errorMessage` from the raw record, when present.
    pub error_message: Option<String>,
    /// `true` for read-only calls (rendered as `R`), `false` for writes (rendered as `W`).
    pub read_only: bool,
    /// CloudTrail `EventId`; the preferred key for de-duplicating events across polls.
    pub event_id: Option<String>,
}
24
25impl ApiEvent {
26 fn dedup_key(&self) -> String {
29 match &self.event_id {
30 Some(id) => id.clone(),
31 None => format!("{}:{}", self.timestamp.timestamp_millis(), self.event_name),
32 }
33 }
34}
35
/// Maximum number of events a [`WatchState`] retains in `events`. Once exceeded, the oldest
/// entries are evicted and their dedup keys are forgotten.
const MAX_EVENTS: usize = 10_000;
38
/// Accumulated view of a session's API activity across successive CloudTrail polls.
#[derive(Debug, Default)]
pub struct WatchState {
    /// Retained events in merge order, capped at `MAX_EVENTS` (oldest evicted first).
    pub events: Vec<ApiEvent>,
    /// Dedup keys of the events currently held in `events`; keys of evicted events are removed.
    seen: HashSet<String>,
    /// Set by `merge` to the later of "now" and just past the newest merged event.
    /// NOTE(review): presumably used as the next poll's start time; confirm with the caller.
    pub last_poll: Option<DateTime<Utc>>,
    /// Cumulative count of distinct events merged; not reduced when events are evicted.
    pub total_api_calls: usize,
    /// Cumulative count of merged events that carried an `error_code`.
    pub error_count: usize,
    /// Every service name observed in a merged event.
    pub services_used: HashSet<String>,
}
50
51impl WatchState {
52 pub fn merge(&mut self, new_events: Vec<ApiEvent>) {
56 let mut newest_ts: Option<DateTime<Utc>> = None;
65 for event in new_events {
66 let key = event.dedup_key();
67 if self.seen.contains(&key) {
68 continue;
69 }
70 if event.error_code.is_some() {
71 self.error_count += 1;
72 }
73 let svc = event.service.clone();
74 self.total_api_calls += 1;
75 self.services_used.insert(svc);
76 self.seen.insert(key);
77 newest_ts = Some(match newest_ts {
78 Some(t) if t >= event.timestamp => t,
79 _ => event.timestamp,
80 });
81 self.events.push(event);
82 }
83 if self.events.len() > MAX_EVENTS {
85 let excess = self.events.len() - MAX_EVENTS;
86 for ev in self.events.drain(..excess) {
87 self.seen.remove(&ev.dedup_key());
88 }
89 }
90 let now = Utc::now();
91 self.last_poll = Some(match newest_ts {
92 Some(newest) => {
93 let advanced = newest + chrono::Duration::milliseconds(1);
94 if advanced > now {
95 advanced
96 } else {
97 now
98 }
99 }
100 None => now,
101 });
102 }
103}
104
105pub async fn poll_cloudtrail(
108 access_key_id: &str,
109 start_time: DateTime<Utc>,
110 region: Option<&str>,
111) -> Result<Vec<ApiEvent>> {
112 let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
113 if let Some(region) = region {
114 loader = loader.region(aws_config::Region::new(region.to_string()));
115 }
116 let config = loader.load().await;
117 let client = aws_sdk_cloudtrail::Client::new(&config);
118
119 let start = aws_sdk_cloudtrail::primitives::DateTime::from_secs(start_time.timestamp());
120
121 let lookup_attr = aws_sdk_cloudtrail::types::LookupAttribute::builder()
122 .attribute_key(aws_sdk_cloudtrail::types::LookupAttributeKey::AccessKeyId)
123 .attribute_value(access_key_id)
124 .build()
125 .map_err(|e| AvError::Sts(format!("CloudTrail attribute build error: {}", e)))?;
126
127 let mut all_raw_events = Vec::new();
129 let mut next_token: Option<String> = None;
130 const MAX_EVENTS: usize = 500;
131
132 loop {
133 let mut req = client
134 .lookup_events()
135 .lookup_attributes(lookup_attr.clone())
136 .start_time(start)
137 .max_results(50);
138
139 if let Some(ref token) = next_token {
140 req = req.next_token(token);
141 }
142
143 let result = req
144 .send()
145 .await
146 .map_err(|e| AvError::Sts(format!("CloudTrail LookupEvents error: {}", e)))?;
147
148 all_raw_events.extend(result.events().iter().cloned());
149
150 if all_raw_events.len() >= MAX_EVENTS {
151 break;
152 }
153
154 match result.next_token() {
155 Some(t) if !t.is_empty() => next_token = Some(t.to_string()),
156 _ => break,
157 }
158 }
159
160 let events = all_raw_events
161 .iter()
162 .filter_map(|e| {
163 let event_name = e.event_name()?.to_string();
164
165 let service = e
167 .event_source()
168 .and_then(|s| crate::learn::event_source_to_service(s).map(|svc| svc.to_string()))
169 .unwrap_or_else(|| "unknown".to_string());
170
171 let read_only = e
172 .read_only()
173 .and_then(|r| r.parse::<bool>().ok())
174 .unwrap_or(true);
175
176 let (error_code, error_message) = e
178 .cloud_trail_event()
179 .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
180 .map(|v| {
181 let code = v
182 .get("errorCode")
183 .and_then(|c| c.as_str())
184 .map(String::from);
185 let msg = v
186 .get("errorMessage")
187 .and_then(|m| m.as_str())
188 .map(String::from);
189 (code, msg)
190 })
191 .unwrap_or((None, None));
192
193 let timestamp = e
194 .event_time()
195 .map(|t| {
196 DateTime::from_timestamp(t.secs(), t.subsec_nanos()).unwrap_or_else(|| {
197 tracing::warn!(
198 secs = t.secs(),
199 "CloudTrail event has unparseable timestamp; falling back to Utc::now()"
200 );
201 Utc::now()
202 })
203 })
204 .unwrap_or_else(|| {
205 tracing::warn!(
206 "CloudTrail event missing event_time; falling back to Utc::now()"
207 );
208 Utc::now()
209 });
210
211 let source_ip = e
212 .cloud_trail_event()
213 .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
214 .and_then(|v| {
215 v.get("sourceIPAddress")
216 .and_then(|s| s.as_str())
217 .map(String::from)
218 });
219
220 let user_agent = e
221 .cloud_trail_event()
222 .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
223 .and_then(|v| {
224 v.get("userAgent")
225 .and_then(|s| s.as_str())
226 .map(String::from)
227 });
228
229 let event_id = e.event_id().map(String::from);
230
231 Some(ApiEvent {
232 timestamp,
233 event_name,
234 service,
235 source_ip,
236 user_agent,
237 error_code,
238 error_message,
239 read_only,
240 event_id,
241 })
242 })
243 .collect();
244
245 Ok(events)
246}
247
248pub fn is_watchable(session: &Session) -> bool {
250 session.status == SessionStatus::Active && !session.is_expired()
251}
252
253pub fn format_event(event: &ApiEvent) -> String {
255 let ts = event.timestamp.format("%H:%M:%S");
256 let rw = if event.read_only { "R" } else { "W" };
257 let status = if let Some(ref code) = event.error_code {
258 format!(" [ERR: {}]", code)
259 } else {
260 String::new()
261 };
262 format!(
263 "[{}] {} {:<12} {}{}",
264 ts, rw, event.service, event.event_name, status
265 )
266}
267
268pub fn format_summary(state: &WatchState, session: &Session) -> String {
270 let mut out = String::new();
271 out.push_str(&format!(
272 "Session {} | {} API calls | {} errors | {} services | {}s remaining\n",
273 session.short_id(),
274 state.total_api_calls,
275 state.error_count,
276 state.services_used.len(),
277 session.remaining_seconds()
278 ));
279 if !state.services_used.is_empty() {
280 let mut svcs: Vec<&str> = state.services_used.iter().map(String::as_str).collect();
281 svcs.sort_unstable();
282 out.push_str(&format!("Services: {}\n", svcs.join(", ")));
283 }
284 out
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an event stamped "now" with no `event_id`, so it relies on the fallback dedup key.
    fn make_event(name: &str, service: &str, error: Option<&str>) -> ApiEvent {
        ApiEvent {
            timestamp: Utc::now(),
            event_name: name.to_string(),
            service: service.to_string(),
            source_ip: Some("10.0.0.1".to_string()),
            user_agent: Some("aws-cli/2.0".to_string()),
            error_code: error.map(String::from),
            error_message: None,
            read_only: true,
            event_id: None,
        }
    }

    /// Builds an s3 event carrying an explicit CloudTrail event id.
    fn make_event_with_id(id: &str, name: &str) -> ApiEvent {
        ApiEvent {
            event_id: Some(id.to_string()),
            ..make_event(name, "s3", None)
        }
    }

    #[test]
    fn test_watch_state_merge() {
        let mut state = WatchState::default();
        let events = vec![
            make_event("GetObject", "s3", None),
            make_event("ListBuckets", "s3", None),
        ];
        state.merge(events);
        assert_eq!(state.total_api_calls, 2);
        assert_eq!(state.error_count, 0);
        assert!(state.services_used.contains("s3"));
        assert_eq!(state.services_used.len(), 1);
    }

    #[test]
    fn test_watch_state_dedup() {
        let mut state = WatchState::default();
        let ts = Utc::now();
        let event = ApiEvent {
            timestamp: ts,
            event_name: "GetObject".to_string(),
            service: "s3".to_string(),
            source_ip: None,
            user_agent: None,
            error_code: None,
            error_message: None,
            read_only: true,
            event_id: None,
        };
        state.merge(vec![event.clone()]);
        state.merge(vec![event]);
        assert_eq!(state.total_api_calls, 1);
    }

    #[test]
    fn test_watch_state_dedup_by_event_id() {
        let mut state = WatchState::default();
        state.merge(vec![make_event_with_id("id-1", "GetObject")]);
        // Same id, different payload: still the same CloudTrail event.
        state.merge(vec![make_event_with_id("id-1", "PutObject")]);
        assert_eq!(state.total_api_calls, 1);
        // A new id is always a new event.
        state.merge(vec![make_event_with_id("id-2", "GetObject")]);
        assert_eq!(state.total_api_calls, 2);
    }

    #[test]
    fn test_watch_state_evicts_oldest_beyond_cap() {
        let mut state = WatchState::default();
        let batch: Vec<ApiEvent> = (0..MAX_EVENTS + 5)
            .map(|i| make_event_with_id(&format!("id-{}", i), "GetObject"))
            .collect();
        state.merge(batch);
        assert_eq!(state.events.len(), MAX_EVENTS);
        assert_eq!(state.seen.len(), MAX_EVENTS);
        // Totals are cumulative and unaffected by eviction.
        assert_eq!(state.total_api_calls, MAX_EVENTS + 5);
        assert_eq!(state.events[0].event_id.as_deref(), Some("id-5"));
    }

    #[test]
    fn test_last_poll_is_not_before_merge() {
        let mut state = WatchState::default();
        let before = Utc::now();
        state.merge(Vec::new());
        assert!(state.last_poll.expect("merge always sets last_poll") >= before);
    }

    #[test]
    fn test_last_poll_moves_past_future_event() {
        let mut state = WatchState::default();
        let future = Utc::now() + chrono::Duration::hours(1);
        let mut event = make_event_with_id("future", "GetObject");
        event.timestamp = future;
        state.merge(vec![event]);
        assert_eq!(
            state.last_poll,
            Some(future + chrono::Duration::milliseconds(1))
        );
    }

    #[test]
    fn test_watch_state_error_count() {
        let mut state = WatchState::default();
        let events = vec![
            make_event("GetObject", "s3", None),
            make_event("PutObject", "s3", Some("AccessDenied")),
        ];
        state.merge(events);
        assert_eq!(state.error_count, 1);
    }

    #[test]
    fn test_watch_state_multi_service() {
        let mut state = WatchState::default();
        let events = vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
            make_event("ListBuckets", "s3", None),
        ];
        state.merge(events);
        assert_eq!(state.services_used.len(), 2);
        assert!(state.services_used.contains("s3"));
        assert!(state.services_used.contains("lambda"));
    }

    #[test]
    fn test_format_event() {
        let event = make_event("GetObject", "s3", None);
        let line = format_event(&event);
        assert!(line.contains("s3"));
        assert!(line.contains("GetObject"));
        // Pin the access column itself, not just any "R" in the line.
        assert!(line.contains("] R "));
    }

    #[test]
    fn test_format_event_write() {
        let mut event = make_event("PutObject", "s3", None);
        event.read_only = false;
        assert!(format_event(&event).contains("] W "));
    }

    #[test]
    fn test_format_event_with_error() {
        let event = make_event("PutObject", "s3", Some("AccessDenied"));
        let line = format_event(&event);
        assert!(line.contains("[ERR: AccessDenied]"));
    }

    #[test]
    fn test_format_summary() {
        let mut state = WatchState::default();
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
        ]);
        let session = crate::session::Session::new(
            std::time::Duration::from_secs(900),
            None,
            crate::policy::ScopedPolicy::from_allow_str("s3:GetObject").unwrap(),
            "arn:aws:iam::123456789012:role/Test".to_string(),
            vec!["test".to_string()],
        );
        let summary = format_summary(&state, &session);
        assert!(summary.contains("2 API calls"));
        assert!(summary.contains("2 services"));
        assert!(summary.contains("Services: lambda, s3"));
    }
}