1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4use crate::error::{AvError, Result};
5use crate::session::{Session, SessionStatus};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ApiEvent {
10 pub timestamp: DateTime<Utc>,
11 pub event_name: String,
12 pub service: String,
13 pub source_ip: Option<String>,
14 pub user_agent: Option<String>,
15 pub error_code: Option<String>,
16 pub error_message: Option<String>,
17 pub read_only: bool,
18}
19
20#[derive(Debug, Default)]
22pub struct WatchState {
23 pub events: Vec<ApiEvent>,
24 pub last_poll: Option<DateTime<Utc>>,
25 pub total_api_calls: usize,
26 pub error_count: usize,
27 pub services_used: Vec<String>,
28}
29
30impl WatchState {
31 pub fn merge(&mut self, new_events: Vec<ApiEvent>) {
33 for event in new_events {
34 let is_dup = self
35 .events
36 .iter()
37 .any(|e| e.timestamp == event.timestamp && e.event_name == event.event_name);
38 if !is_dup {
39 if event.error_code.is_some() {
40 self.error_count += 1;
41 }
42 let svc = event.service.clone();
43 self.events.push(event);
44 self.total_api_calls += 1;
45 if !self.services_used.contains(&svc) {
46 self.services_used.push(svc);
47 }
48 }
49 }
50 self.last_poll = Some(Utc::now());
51 }
52}
53
54pub async fn poll_cloudtrail(
57 access_key_id: &str,
58 start_time: DateTime<Utc>,
59 region: Option<&str>,
60) -> Result<Vec<ApiEvent>> {
61 let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
62 if let Some(region) = region {
63 loader = loader.region(aws_config::Region::new(region.to_string()));
64 }
65 let config = loader.load().await;
66 let client = aws_sdk_cloudtrail::Client::new(&config);
67
68 let start = aws_sdk_cloudtrail::primitives::DateTime::from_secs(start_time.timestamp());
69
70 let result = client
71 .lookup_events()
72 .lookup_attributes(
73 aws_sdk_cloudtrail::types::LookupAttribute::builder()
74 .attribute_key(aws_sdk_cloudtrail::types::LookupAttributeKey::AccessKeyId)
75 .attribute_value(access_key_id)
76 .build()
77 .map_err(|e| AvError::Sts(format!("CloudTrail attribute build error: {}", e)))?,
78 )
79 .start_time(start)
80 .max_results(50)
81 .send()
82 .await
83 .map_err(|e| AvError::Sts(format!("CloudTrail LookupEvents error: {}", e)))?;
84
85 let events = result
86 .events()
87 .iter()
88 .filter_map(|e| {
89 let event_name = e.event_name()?.to_string();
90
91 let service = e
93 .event_source()
94 .map(|s| s.split('.').next().unwrap_or(s).to_string())
95 .unwrap_or_else(|| "unknown".to_string());
96
97 let read_only = e
98 .read_only()
99 .and_then(|r| r.parse::<bool>().ok())
100 .unwrap_or(true);
101
102 let (error_code, error_message) = e
104 .cloud_trail_event()
105 .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
106 .map(|v| {
107 let code = v
108 .get("errorCode")
109 .and_then(|c| c.as_str())
110 .map(String::from);
111 let msg = v
112 .get("errorMessage")
113 .and_then(|m| m.as_str())
114 .map(String::from);
115 (code, msg)
116 })
117 .unwrap_or((None, None));
118
119 let timestamp = e
120 .event_time()
121 .map(|t| {
122 DateTime::from_timestamp(t.secs(), t.subsec_nanos()).unwrap_or_else(Utc::now)
123 })
124 .unwrap_or_else(Utc::now);
125
126 let source_ip = e
127 .cloud_trail_event()
128 .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
129 .and_then(|v| {
130 v.get("sourceIPAddress")
131 .and_then(|s| s.as_str())
132 .map(String::from)
133 });
134
135 let user_agent = e
136 .cloud_trail_event()
137 .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
138 .and_then(|v| {
139 v.get("userAgent")
140 .and_then(|s| s.as_str())
141 .map(String::from)
142 });
143
144 Some(ApiEvent {
145 timestamp,
146 event_name,
147 service,
148 source_ip,
149 user_agent,
150 error_code,
151 error_message,
152 read_only,
153 })
154 })
155 .collect();
156
157 Ok(events)
158}
159
160pub fn is_watchable(session: &Session) -> bool {
162 session.status == SessionStatus::Active && !session.is_expired()
163}
164
165pub fn format_event(event: &ApiEvent) -> String {
167 let ts = event.timestamp.format("%H:%M:%S");
168 let rw = if event.read_only { "R" } else { "W" };
169 let status = if let Some(ref code) = event.error_code {
170 format!(" [ERR: {}]", code)
171 } else {
172 String::new()
173 };
174 format!(
175 "[{}] {} {:<12} {}{}",
176 ts, rw, event.service, event.event_name, status
177 )
178}
179
180pub fn format_summary(state: &WatchState, session: &Session) -> String {
182 let mut out = String::new();
183 out.push_str(&format!(
184 "Session {} | {} API calls | {} errors | {} services | {}s remaining\n",
185 &session.id[..8],
186 state.total_api_calls,
187 state.error_count,
188 state.services_used.len(),
189 session.remaining_seconds()
190 ));
191 if !state.services_used.is_empty() {
192 out.push_str(&format!("Services: {}\n", state.services_used.join(", ")));
193 }
194 out
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a read-only event stamped with the current time and fixed IP/user agent.
    fn make_event(name: &str, service: &str, error: Option<&str>) -> ApiEvent {
        ApiEvent {
            timestamp: Utc::now(),
            event_name: name.to_string(),
            service: service.to_string(),
            source_ip: Some("10.0.0.1".to_string()),
            user_agent: Some("aws-cli/2.0".to_string()),
            error_code: error.map(String::from),
            error_message: None,
            read_only: true,
        }
    }

    #[test]
    fn test_watch_state_merge() {
        let mut state = WatchState::default();
        assert!(state.last_poll.is_none());
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("ListBuckets", "s3", None),
        ]);
        assert_eq!(state.total_api_calls, 2);
        assert_eq!(state.events.len(), 2);
        assert_eq!(state.error_count, 0);
        assert_eq!(state.services_used, vec!["s3"]);
        assert!(state.last_poll.is_some());
    }

    #[test]
    fn test_watch_state_dedup() {
        let mut state = WatchState::default();
        let event = ApiEvent {
            timestamp: Utc::now(),
            event_name: "GetObject".to_string(),
            service: "s3".to_string(),
            source_ip: None,
            user_agent: None,
            error_code: None,
            error_message: None,
            read_only: true,
        };
        state.merge(vec![event.clone()]);
        state.merge(vec![event]);
        assert_eq!(state.total_api_calls, 1);
        assert_eq!(state.events.len(), 1);
    }

    #[test]
    fn test_watch_state_dedup_within_batch() {
        let mut state = WatchState::default();
        let event = make_event("GetObject", "s3", None);
        state.merge(vec![event.clone(), event]);
        assert_eq!(state.total_api_calls, 1);
    }

    #[test]
    fn test_watch_state_dedup_does_not_double_count_errors() {
        let mut state = WatchState::default();
        let event = make_event("PutObject", "s3", Some("AccessDenied"));
        state.merge(vec![event.clone()]);
        state.merge(vec![event]);
        assert_eq!(state.total_api_calls, 1);
        assert_eq!(state.error_count, 1);
    }

    #[test]
    fn test_watch_state_error_count() {
        let mut state = WatchState::default();
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("PutObject", "s3", Some("AccessDenied")),
        ]);
        assert_eq!(state.error_count, 1);
        assert_eq!(state.total_api_calls, 2);
    }

    #[test]
    fn test_watch_state_multi_service() {
        let mut state = WatchState::default();
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
            make_event("ListBuckets", "s3", None),
        ]);
        // Services are listed once each, in first-seen order.
        assert_eq!(state.services_used, vec!["s3", "lambda"]);
    }

    #[test]
    fn test_format_event() {
        let event = make_event("GetObject", "s3", None);
        let line = format_event(&event);
        let expected = format!(
            "[{}] R {:<12} GetObject",
            event.timestamp.format("%H:%M:%S"),
            "s3"
        );
        assert_eq!(line, expected);
    }

    #[test]
    fn test_format_event_write() {
        let mut event = make_event("PutObject", "s3", None);
        event.read_only = false;
        let line = format_event(&event);
        let expected = format!(
            "[{}] W {:<12} PutObject",
            event.timestamp.format("%H:%M:%S"),
            "s3"
        );
        assert_eq!(line, expected);
    }

    #[test]
    fn test_format_event_with_error() {
        let event = make_event("PutObject", "s3", Some("AccessDenied"));
        let line = format_event(&event);
        assert!(line.ends_with(" PutObject [ERR: AccessDenied]"));
    }

    #[test]
    fn test_format_summary() {
        let mut state = WatchState::default();
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
        ]);
        let session = crate::session::Session::new(
            std::time::Duration::from_secs(900),
            None,
            crate::policy::ScopedPolicy::from_allow_str("s3:GetObject").unwrap(),
            "arn:aws:iam::123456789012:role/Test".to_string(),
            vec!["test".to_string()],
        );
        let summary = format_summary(&state, &session);
        assert!(summary.starts_with("Session "));
        assert!(summary.contains("2 API calls"));
        assert!(summary.contains("0 errors"));
        assert!(summary.contains("2 services"));
        assert!(summary.contains("Services: s3, lambda\n"));
        assert_eq!(summary.lines().count(), 2);
    }
}