Skip to main content

tryaudex_core/
watch.rs

1use std::collections::HashSet;
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5
6use crate::error::{AvError, Result};
7use crate::session::{Session, SessionStatus};
8
/// A single API event observed during a session.
///
/// Instances are produced by `poll_cloudtrail` from CloudTrail lookup
/// results and accumulated in `WatchState`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiEvent {
    /// Event time as reported by CloudTrail. `poll_cloudtrail` substitutes
    /// the local `Utc::now()` when the time is missing or unparseable.
    pub timestamp: DateTime<Utc>,
    /// CloudTrail event name (e.g. `GetObject`).
    pub event_name: String,
    /// Service name mapped from the CloudTrail event source, or `"unknown"`
    /// when the source is absent or has no mapping.
    pub service: String,
    /// `sourceIPAddress` field of the raw CloudTrail event JSON, if present.
    pub source_ip: Option<String>,
    /// `userAgent` field of the raw CloudTrail event JSON, if present.
    pub user_agent: Option<String>,
    /// `errorCode` field of the raw CloudTrail event JSON; `None` when the
    /// field is absent. `WatchState::merge` counts events with `Some` here
    /// as errors.
    pub error_code: Option<String>,
    /// `errorMessage` field of the raw CloudTrail event JSON, if present.
    pub error_message: Option<String>,
    /// Whether the call was read-only. `poll_cloudtrail` defaults this to
    /// `true` when CloudTrail does not supply a parseable value.
    pub read_only: bool,
    /// CloudTrail event ID. Used as the dedup key when available; falls back
    /// to a "{timestamp_millis}:{event_name}" composite when absent.
    pub event_id: Option<String>,
}
24
25impl ApiEvent {
26    /// Stable dedup key: the CloudTrail event ID when present, otherwise a
27    /// composite of millisecond timestamp and event name.
28    fn dedup_key(&self) -> String {
29        match &self.event_id {
30            Some(id) => id.clone(),
31            None => format!("{}:{}", self.timestamp.timestamp_millis(), self.event_name),
32        }
33    }
34}
35
/// Maximum number of events retained in memory to prevent unbounded growth.
///
/// `WatchState::merge` drops events from the front of `WatchState::events`
/// (and their dedup keys) once this many are stored.
const MAX_EVENTS: usize = 10_000;
38
/// Accumulated watch state for a session.
///
/// Populated exclusively through `WatchState::merge`; `Default` yields an
/// empty state with no events and `last_poll == None`.
#[derive(Debug, Default)]
pub struct WatchState {
    /// Retained events in the order they were merged, capped at `MAX_EVENTS`.
    pub events: Vec<ApiEvent>,
    /// Fast dedup lookup keyed by `ApiEvent::dedup_key()`.
    /// Keys are removed again when their event is evicted from `events`.
    seen: HashSet<String>,
    /// Timestamp recorded at the end of the most recent `merge`; `None`
    /// until the first merge.
    pub last_poll: Option<DateTime<Utc>>,
    /// Number of unique events merged so far. Not reduced when events are
    /// evicted, so it can exceed `events.len()`.
    pub total_api_calls: usize,
    /// Number of unique merged events that carried an `error_code`.
    pub error_count: usize,
    /// Distinct `ApiEvent::service` values seen across all merged events.
    pub services_used: HashSet<String>,
}
50
51impl WatchState {
52    /// Merge new events into the state, deduplicating by event ID (falling
53    /// back to timestamp+event_name). Caps stored events at `MAX_EVENTS`,
54    /// discarding oldest when full.
55    pub fn merge(&mut self, new_events: Vec<ApiEvent>) {
56        // R6-M12: track the newest CloudTrail event timestamp from this
57        // batch and advance `last_poll` to *just past* it, rather than
58        // anchoring to local `Utc::now()`. Under clock skew between the
59        // local machine and CloudTrail, `Utc::now()` could drift behind
60        // the newest event timestamp — each subsequent poll would then
61        // re-fetch the same events, dedup them, and observe zero new
62        // activity forever. Using `newest_event_ts + 1ms` guarantees
63        // monotonic forward progress regardless of clock drift.
64        let mut newest_ts: Option<DateTime<Utc>> = None;
65        for event in new_events {
66            let key = event.dedup_key();
67            if self.seen.contains(&key) {
68                continue;
69            }
70            if event.error_code.is_some() {
71                self.error_count += 1;
72            }
73            let svc = event.service.clone();
74            self.total_api_calls += 1;
75            self.services_used.insert(svc);
76            self.seen.insert(key);
77            newest_ts = Some(match newest_ts {
78                Some(t) if t >= event.timestamp => t,
79                _ => event.timestamp,
80            });
81            self.events.push(event);
82        }
83        // Evict oldest events if over cap
84        if self.events.len() > MAX_EVENTS {
85            let excess = self.events.len() - MAX_EVENTS;
86            for ev in self.events.drain(..excess) {
87                self.seen.remove(&ev.dedup_key());
88            }
89        }
90        let now = Utc::now();
91        self.last_poll = Some(match newest_ts {
92            Some(newest) => {
93                let advanced = newest + chrono::Duration::milliseconds(1);
94                if advanced > now {
95                    advanced
96                } else {
97                    now
98                }
99            }
100            None => now,
101        });
102    }
103}
104
105/// Poll CloudTrail LookupEvents for recent activity by an access key.
106/// Returns parsed API events since `start_time`.
107pub async fn poll_cloudtrail(
108    access_key_id: &str,
109    start_time: DateTime<Utc>,
110    region: Option<&str>,
111) -> Result<Vec<ApiEvent>> {
112    let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
113    if let Some(region) = region {
114        loader = loader.region(aws_config::Region::new(region.to_string()));
115    }
116    let config = loader.load().await;
117    let client = aws_sdk_cloudtrail::Client::new(&config);
118
119    let start = aws_sdk_cloudtrail::primitives::DateTime::from_secs(start_time.timestamp());
120
121    let lookup_attr = aws_sdk_cloudtrail::types::LookupAttribute::builder()
122        .attribute_key(aws_sdk_cloudtrail::types::LookupAttributeKey::AccessKeyId)
123        .attribute_value(access_key_id)
124        .build()
125        .map_err(|e| AvError::Sts(format!("CloudTrail attribute build error: {}", e)))?;
126
127    // Paginate to collect all events (max 500 to bound memory/API calls).
128    let mut all_raw_events = Vec::new();
129    let mut next_token: Option<String> = None;
130    const MAX_EVENTS: usize = 500;
131
132    loop {
133        let mut req = client
134            .lookup_events()
135            .lookup_attributes(lookup_attr.clone())
136            .start_time(start)
137            .max_results(50);
138
139        if let Some(ref token) = next_token {
140            req = req.next_token(token);
141        }
142
143        let result = req
144            .send()
145            .await
146            .map_err(|e| AvError::Sts(format!("CloudTrail LookupEvents error: {}", e)))?;
147
148        all_raw_events.extend(result.events().iter().cloned());
149
150        if all_raw_events.len() >= MAX_EVENTS {
151            break;
152        }
153
154        match result.next_token() {
155            Some(t) if !t.is_empty() => next_token = Some(t.to_string()),
156            _ => break,
157        }
158    }
159
160    let events = all_raw_events
161        .iter()
162        .filter_map(|e| {
163            let event_name = e.event_name()?.to_string();
164
165            // Parse service from event source using the shared CloudTrail mapping
166            let service = e
167                .event_source()
168                .and_then(|s| crate::learn::event_source_to_service(s).map(|svc| svc.to_string()))
169                .unwrap_or_else(|| "unknown".to_string());
170
171            let read_only = e
172                .read_only()
173                .and_then(|r| r.parse::<bool>().ok())
174                .unwrap_or(true);
175
176            // Parse error info from CloudTrail event JSON if available
177            let (error_code, error_message) = e
178                .cloud_trail_event()
179                .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
180                .map(|v| {
181                    let code = v
182                        .get("errorCode")
183                        .and_then(|c| c.as_str())
184                        .map(String::from);
185                    let msg = v
186                        .get("errorMessage")
187                        .and_then(|m| m.as_str())
188                        .map(String::from);
189                    (code, msg)
190                })
191                .unwrap_or((None, None));
192
193            let timestamp = e
194                .event_time()
195                .map(|t| {
196                    DateTime::from_timestamp(t.secs(), t.subsec_nanos()).unwrap_or_else(|| {
197                        tracing::warn!(
198                            secs = t.secs(),
199                            "CloudTrail event has unparseable timestamp; falling back to Utc::now()"
200                        );
201                        Utc::now()
202                    })
203                })
204                .unwrap_or_else(|| {
205                    tracing::warn!(
206                        "CloudTrail event missing event_time; falling back to Utc::now()"
207                    );
208                    Utc::now()
209                });
210
211            let source_ip = e
212                .cloud_trail_event()
213                .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
214                .and_then(|v| {
215                    v.get("sourceIPAddress")
216                        .and_then(|s| s.as_str())
217                        .map(String::from)
218                });
219
220            let user_agent = e
221                .cloud_trail_event()
222                .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
223                .and_then(|v| {
224                    v.get("userAgent")
225                        .and_then(|s| s.as_str())
226                        .map(String::from)
227                });
228
229            let event_id = e.event_id().map(String::from);
230
231            Some(ApiEvent {
232                timestamp,
233                event_name,
234                service,
235                source_ip,
236                user_agent,
237                error_code,
238                error_message,
239                read_only,
240                event_id,
241            })
242        })
243        .collect();
244
245    Ok(events)
246}
247
248/// Check if a session is still watchable (active and not expired).
249pub fn is_watchable(session: &Session) -> bool {
250    session.status == SessionStatus::Active && !session.is_expired()
251}
252
253/// Format an event as a colored terminal line.
254pub fn format_event(event: &ApiEvent) -> String {
255    let ts = event.timestamp.format("%H:%M:%S");
256    let rw = if event.read_only { "R" } else { "W" };
257    let status = if let Some(ref code) = event.error_code {
258        format!(" [ERR: {}]", code)
259    } else {
260        String::new()
261    };
262    format!(
263        "[{}] {} {:<12} {}{}",
264        ts, rw, event.service, event.event_name, status
265    )
266}
267
268/// Format a summary of the current watch state.
269pub fn format_summary(state: &WatchState, session: &Session) -> String {
270    let mut out = String::new();
271    out.push_str(&format!(
272        "Session {} | {} API calls | {} errors | {} services | {}s remaining\n",
273        session.short_id(),
274        state.total_api_calls,
275        state.error_count,
276        state.services_used.len(),
277        session.remaining_seconds()
278    ));
279    if !state.services_used.is_empty() {
280        let mut svcs: Vec<&str> = state.services_used.iter().map(String::as_str).collect();
281        svcs.sort_unstable();
282        out.push_str(&format!("Services: {}\n", svcs.join(", ")));
283    }
284    out
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a read-only event stamped with the current time and no
    /// CloudTrail event ID, so dedup falls back to timestamp + name.
    fn make_event(name: &str, service: &str, error: Option<&str>) -> ApiEvent {
        ApiEvent {
            timestamp: Utc::now(),
            event_name: name.to_string(),
            service: service.to_string(),
            source_ip: Some("10.0.0.1".to_string()),
            user_agent: Some("aws-cli/2.0".to_string()),
            error_code: error.map(String::from),
            error_message: None,
            read_only: true,
            event_id: None,
        }
    }

    /// Build a successful `s3` event that carries an explicit event ID.
    fn make_event_with_id(id: &str, name: &str) -> ApiEvent {
        ApiEvent {
            event_id: Some(id.to_string()),
            ..make_event(name, "s3", None)
        }
    }

    #[test]
    fn test_watch_state_merge() {
        let mut state = WatchState::default();
        let events = vec![
            make_event("GetObject", "s3", None),
            make_event("ListBuckets", "s3", None),
        ];
        state.merge(events);
        assert_eq!(state.total_api_calls, 2);
        assert_eq!(state.error_count, 0);
        assert!(state.services_used.contains("s3"));
        assert_eq!(state.services_used.len(), 1);
    }

    #[test]
    fn test_watch_state_dedup() {
        let mut state = WatchState::default();
        let ts = Utc::now();
        let event = ApiEvent {
            timestamp: ts,
            event_name: "GetObject".to_string(),
            service: "s3".to_string(),
            source_ip: None,
            user_agent: None,
            error_code: None,
            error_message: None,
            read_only: true,
            event_id: None,
        };
        state.merge(vec![event.clone()]);
        state.merge(vec![event]);
        assert_eq!(state.total_api_calls, 1); // deduped
    }

    #[test]
    fn test_watch_state_dedup_by_event_id() {
        let mut state = WatchState::default();

        // Same event ID => duplicate, even though the names differ.
        state.merge(vec![
            make_event_with_id("evt-1", "GetObject"),
            make_event_with_id("evt-1", "PutObject"),
        ]);
        assert_eq!(state.total_api_calls, 1);
        assert_eq!(state.events.len(), 1);

        // Distinct event IDs => both kept, even with identical timestamp
        // and name (which would collide under the fallback key).
        let ts = Utc::now();
        let first = ApiEvent {
            timestamp: ts,
            ..make_event_with_id("evt-2", "GetObject")
        };
        let second = ApiEvent {
            timestamp: ts,
            ..make_event_with_id("evt-3", "GetObject")
        };
        state.merge(vec![first, second]);
        assert_eq!(state.total_api_calls, 3);
        assert_eq!(state.events.len(), 3);
    }

    #[test]
    fn test_watch_state_evicts_over_cap() {
        let mut state = WatchState::default();
        let events: Vec<ApiEvent> = (0..MAX_EVENTS + 1)
            .map(|i| make_event_with_id(&format!("evt-{}", i), "GetObject"))
            .collect();
        state.merge(events);

        // Storage and the dedup index are capped; the counter is cumulative.
        assert_eq!(state.events.len(), MAX_EVENTS);
        assert_eq!(state.seen.len(), MAX_EVENTS);
        assert_eq!(state.total_api_calls, MAX_EVENTS + 1);
        // The first-merged event is the one that was dropped.
        assert_eq!(state.events[0].event_id.as_deref(), Some("evt-1"));
    }

    #[test]
    fn test_watch_state_last_poll() {
        let mut state = WatchState::default();
        assert!(state.last_poll.is_none());

        // Even an empty batch records a poll time.
        state.merge(Vec::new());
        assert!(state.last_poll.is_some());

        // After merging an event, the poll time lies strictly past it.
        let event = make_event("GetObject", "s3", None);
        let event_ts = event.timestamp;
        state.merge(vec![event]);
        assert!(state.last_poll.unwrap() > event_ts);
    }

    #[test]
    fn test_watch_state_error_count() {
        let mut state = WatchState::default();
        let events = vec![
            make_event("GetObject", "s3", None),
            make_event("PutObject", "s3", Some("AccessDenied")),
        ];
        state.merge(events);
        assert_eq!(state.error_count, 1);
    }

    #[test]
    fn test_watch_state_multi_service() {
        let mut state = WatchState::default();
        let events = vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
            make_event("ListBuckets", "s3", None),
        ];
        state.merge(events);
        assert_eq!(state.services_used.len(), 2);
        assert!(state.services_used.contains("s3"));
        assert!(state.services_used.contains("lambda"));
    }

    #[test]
    fn test_format_event() {
        let event = make_event("GetObject", "s3", None);
        let line = format_event(&event);
        assert!(line.starts_with('['));
        // Exact layout after the timestamp: flag, padded service, name.
        assert!(line.ends_with(&format!("] R {:<12} GetObject", "s3")));
    }

    #[test]
    fn test_format_event_write() {
        let event = ApiEvent {
            read_only: false,
            ..make_event("PutObject", "s3", None)
        };
        let line = format_event(&event);
        assert!(line.ends_with(&format!("] W {:<12} PutObject", "s3")));
    }

    #[test]
    fn test_format_event_with_error() {
        let event = make_event("PutObject", "s3", Some("AccessDenied"));
        let line = format_event(&event);
        assert!(line.ends_with(" [ERR: AccessDenied]"));
    }

    #[test]
    fn test_format_summary() {
        let mut state = WatchState::default();
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
        ]);
        let session = crate::session::Session::new(
            std::time::Duration::from_secs(900),
            None,
            crate::policy::ScopedPolicy::from_allow_str("s3:GetObject").unwrap(),
            "arn:aws:iam::123456789012:role/Test".to_string(),
            vec!["test".to_string()],
        );
        let summary = format_summary(&state, &session);
        assert!(summary.contains("2 API calls"));
        assert!(summary.contains("2 services"));
        // Service names are listed in sorted order.
        assert!(summary.contains("Services: lambda, s3\n"));
    }
}