Skip to main content

tryaudex_core/
watch.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4use crate::error::{AvError, Result};
5use crate::session::{Session, SessionStatus};
6
/// A single API event observed during a session.
///
/// Instances are produced by [`poll_cloudtrail`] from CloudTrail lookup results
/// and accumulated in [`WatchState`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiEvent {
    /// When the event occurred. `poll_cloudtrail` falls back to the current
    /// time if CloudTrail supplies no usable event time.
    pub timestamp: DateTime<Utc>,
    /// CloudTrail event name (e.g. `GetObject`).
    pub event_name: String,
    /// Service label: the part of the event source before the first `.`
    /// (e.g. `s3.amazonaws.com` -> `s3`), or `"unknown"` when no source is given.
    pub service: String,
    /// `sourceIPAddress` from the raw CloudTrail record, if present.
    pub source_ip: Option<String>,
    /// `userAgent` from the raw CloudTrail record, if present.
    pub user_agent: Option<String>,
    /// `errorCode` from the raw CloudTrail record; `Some` marks a failed call.
    pub error_code: Option<String>,
    /// `errorMessage` from the raw CloudTrail record, if present.
    pub error_message: Option<String>,
    /// Whether the call was read-only. `poll_cloudtrail` defaults this to
    /// `true` when CloudTrail's flag is missing or unparseable.
    pub read_only: bool,
}
19
/// Accumulated watch state for a session.
///
/// Starts empty via `Default` and is updated through [`WatchState::merge`].
#[derive(Debug, Default)]
pub struct WatchState {
    /// Every distinct event merged so far, in the order it was first merged.
    pub events: Vec<ApiEvent>,
    /// Time of the most recent `merge` call; `None` until the first merge.
    pub last_poll: Option<DateTime<Utc>>,
    /// Number of distinct (non-duplicate) events merged.
    pub total_api_calls: usize,
    /// Number of distinct events that carried an `error_code`.
    pub error_count: usize,
    /// Distinct service labels seen, in first-seen order.
    pub services_used: Vec<String>,
}
29
30impl WatchState {
31    /// Merge new events into the state, deduplicating by timestamp+event_name.
32    pub fn merge(&mut self, new_events: Vec<ApiEvent>) {
33        for event in new_events {
34            let is_dup = self
35                .events
36                .iter()
37                .any(|e| e.timestamp == event.timestamp && e.event_name == event.event_name);
38            if !is_dup {
39                if event.error_code.is_some() {
40                    self.error_count += 1;
41                }
42                let svc = event.service.clone();
43                self.events.push(event);
44                self.total_api_calls += 1;
45                if !self.services_used.contains(&svc) {
46                    self.services_used.push(svc);
47                }
48            }
49        }
50        self.last_poll = Some(Utc::now());
51    }
52}
53
54/// Poll CloudTrail LookupEvents for recent activity by an access key.
55/// Returns parsed API events since `start_time`.
56pub async fn poll_cloudtrail(
57    access_key_id: &str,
58    start_time: DateTime<Utc>,
59    region: Option<&str>,
60) -> Result<Vec<ApiEvent>> {
61    let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
62    if let Some(region) = region {
63        loader = loader.region(aws_config::Region::new(region.to_string()));
64    }
65    let config = loader.load().await;
66    let client = aws_sdk_cloudtrail::Client::new(&config);
67
68    let start = aws_sdk_cloudtrail::primitives::DateTime::from_secs(start_time.timestamp());
69
70    let result = client
71        .lookup_events()
72        .lookup_attributes(
73            aws_sdk_cloudtrail::types::LookupAttribute::builder()
74                .attribute_key(aws_sdk_cloudtrail::types::LookupAttributeKey::AccessKeyId)
75                .attribute_value(access_key_id)
76                .build()
77                .map_err(|e| AvError::Sts(format!("CloudTrail attribute build error: {}", e)))?,
78        )
79        .start_time(start)
80        .max_results(50)
81        .send()
82        .await
83        .map_err(|e| AvError::Sts(format!("CloudTrail LookupEvents error: {}", e)))?;
84
85    let events = result
86        .events()
87        .iter()
88        .filter_map(|e| {
89            let event_name = e.event_name()?.to_string();
90
91            // Parse service from event source (e.g. "s3.amazonaws.com" -> "s3")
92            let service = e
93                .event_source()
94                .map(|s| s.split('.').next().unwrap_or(s).to_string())
95                .unwrap_or_else(|| "unknown".to_string());
96
97            let read_only = e
98                .read_only()
99                .and_then(|r| r.parse::<bool>().ok())
100                .unwrap_or(true);
101
102            // Parse error info from CloudTrail event JSON if available
103            let (error_code, error_message) = e
104                .cloud_trail_event()
105                .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
106                .map(|v| {
107                    let code = v
108                        .get("errorCode")
109                        .and_then(|c| c.as_str())
110                        .map(String::from);
111                    let msg = v
112                        .get("errorMessage")
113                        .and_then(|m| m.as_str())
114                        .map(String::from);
115                    (code, msg)
116                })
117                .unwrap_or((None, None));
118
119            let timestamp = e
120                .event_time()
121                .map(|t| {
122                    DateTime::from_timestamp(t.secs(), t.subsec_nanos()).unwrap_or_else(Utc::now)
123                })
124                .unwrap_or_else(Utc::now);
125
126            let source_ip = e
127                .cloud_trail_event()
128                .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
129                .and_then(|v| {
130                    v.get("sourceIPAddress")
131                        .and_then(|s| s.as_str())
132                        .map(String::from)
133                });
134
135            let user_agent = e
136                .cloud_trail_event()
137                .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
138                .and_then(|v| {
139                    v.get("userAgent")
140                        .and_then(|s| s.as_str())
141                        .map(String::from)
142                });
143
144            Some(ApiEvent {
145                timestamp,
146                event_name,
147                service,
148                source_ip,
149                user_agent,
150                error_code,
151                error_message,
152                read_only,
153            })
154        })
155        .collect();
156
157    Ok(events)
158}
159
160/// Check if a session is still watchable (active and not expired).
161pub fn is_watchable(session: &Session) -> bool {
162    session.status == SessionStatus::Active && !session.is_expired()
163}
164
165/// Format an event as a colored terminal line.
166pub fn format_event(event: &ApiEvent) -> String {
167    let ts = event.timestamp.format("%H:%M:%S");
168    let rw = if event.read_only { "R" } else { "W" };
169    let status = if let Some(ref code) = event.error_code {
170        format!(" [ERR: {}]", code)
171    } else {
172        String::new()
173    };
174    format!(
175        "[{}] {} {:<12} {}{}",
176        ts, rw, event.service, event.event_name, status
177    )
178}
179
180/// Format a summary of the current watch state.
181pub fn format_summary(state: &WatchState, session: &Session) -> String {
182    let mut out = String::new();
183    out.push_str(&format!(
184        "Session {} | {} API calls | {} errors | {} services | {}s remaining\n",
185        &session.id[..8],
186        state.total_api_calls,
187        state.error_count,
188        state.services_used.len(),
189        session.remaining_seconds()
190    ));
191    if !state.services_used.is_empty() {
192        out.push_str(&format!("Services: {}\n", state.services_used.join(", ")));
193    }
194    out
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a read-only event stamped with the current time; `error`, when
    /// given, becomes the event's `error_code`.
    fn make_event(name: &str, service: &str, error: Option<&str>) -> ApiEvent {
        ApiEvent {
            timestamp: Utc::now(),
            event_name: name.to_owned(),
            service: service.to_owned(),
            source_ip: Some(String::from("10.0.0.1")),
            user_agent: Some(String::from("aws-cli/2.0")),
            error_code: error.map(str::to_owned),
            error_message: None,
            read_only: true,
        }
    }

    /// Merge `events` into a fresh default state and hand the state back.
    fn merged(events: Vec<ApiEvent>) -> WatchState {
        let mut state = WatchState::default();
        state.merge(events);
        state
    }

    // Two distinct successful calls to one service.
    #[test]
    fn test_watch_state_merge() {
        let state = merged(vec![
            make_event("GetObject", "s3", None),
            make_event("ListBuckets", "s3", None),
        ]);
        assert_eq!(state.total_api_calls, 2);
        assert_eq!(state.error_count, 0);
        assert_eq!(state.services_used, vec!["s3"]);
    }

    // Merging the very same event twice must count it once.
    #[test]
    fn test_watch_state_dedup() {
        let original = ApiEvent {
            timestamp: Utc::now(),
            event_name: "GetObject".to_string(),
            service: "s3".to_string(),
            source_ip: None,
            user_agent: None,
            error_code: None,
            error_message: None,
            read_only: true,
        };
        let repeat = original.clone();

        let mut state = WatchState::default();
        state.merge(vec![original]);
        state.merge(vec![repeat]);
        assert_eq!(state.total_api_calls, 1); // deduped
    }

    // Only the event carrying an error code is counted as an error.
    #[test]
    fn test_watch_state_error_count() {
        let state = merged(vec![
            make_event("GetObject", "s3", None),
            make_event("PutObject", "s3", Some("AccessDenied")),
        ]);
        assert_eq!(state.error_count, 1);
    }

    // Services are tracked once each, however many calls hit them.
    #[test]
    fn test_watch_state_multi_service() {
        let state = merged(vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
            make_event("ListBuckets", "s3", None),
        ]);
        assert_eq!(state.services_used.len(), 2);
        for expected in ["s3", "lambda"] {
            assert!(state.services_used.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_format_event() {
        let rendered = format_event(&make_event("GetObject", "s3", None));
        assert!(rendered.contains("s3"));
        assert!(rendered.contains("GetObject"));
        assert!(rendered.contains("R")); // read-only
    }

    #[test]
    fn test_format_event_with_error() {
        let rendered = format_event(&make_event("PutObject", "s3", Some("AccessDenied")));
        assert!(rendered.contains("AccessDenied"));
    }

    #[test]
    fn test_format_summary() {
        let state = merged(vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
        ]);
        let session = crate::session::Session::new(
            std::time::Duration::from_secs(900),
            None,
            crate::policy::ScopedPolicy::from_allow_str("s3:GetObject").unwrap(),
            "arn:aws:iam::123456789012:role/Test".to_string(),
            vec!["test".to_string()],
        );
        let summary = format_summary(&state, &session);
        assert!(summary.contains("2 API calls"));
        assert!(summary.contains("2 services"));
    }
}