Skip to main content

tryaudex_core/
watch.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4use crate::error::{AvError, Result};
5use crate::session::{Session, SessionStatus};
6
7/// A single API event observed during a session.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ApiEvent {
10    pub timestamp: DateTime<Utc>,
11    pub event_name: String,
12    pub service: String,
13    pub source_ip: Option<String>,
14    pub user_agent: Option<String>,
15    pub error_code: Option<String>,
16    pub error_message: Option<String>,
17    pub read_only: bool,
18}
19
20/// Accumulated watch state for a session.
21#[derive(Debug, Default)]
22pub struct WatchState {
23    pub events: Vec<ApiEvent>,
24    pub last_poll: Option<DateTime<Utc>>,
25    pub total_api_calls: usize,
26    pub error_count: usize,
27    pub services_used: Vec<String>,
28}
29
30impl WatchState {
31    /// Merge new events into the state, deduplicating by timestamp+event_name.
32    pub fn merge(&mut self, new_events: Vec<ApiEvent>) {
33        for event in new_events {
34            let is_dup = self.events.iter().any(|e| {
35                e.timestamp == event.timestamp && e.event_name == event.event_name
36            });
37            if !is_dup {
38                if event.error_code.is_some() {
39                    self.error_count += 1;
40                }
41                let svc = event.service.clone();
42                self.events.push(event);
43                self.total_api_calls += 1;
44                if !self.services_used.contains(&svc) {
45                    self.services_used.push(svc);
46                }
47            }
48        }
49        self.last_poll = Some(Utc::now());
50    }
51}
52
53/// Poll CloudTrail LookupEvents for recent activity by an access key.
54/// Returns parsed API events since `start_time`.
55pub async fn poll_cloudtrail(
56    access_key_id: &str,
57    start_time: DateTime<Utc>,
58    region: Option<&str>,
59) -> Result<Vec<ApiEvent>> {
60    let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
61    if let Some(region) = region {
62        loader = loader.region(aws_config::Region::new(region.to_string()));
63    }
64    let config = loader.load().await;
65    let client = aws_sdk_cloudtrail::Client::new(&config);
66
67    let start = aws_sdk_cloudtrail::primitives::DateTime::from_secs(start_time.timestamp());
68
69    let result = client
70        .lookup_events()
71        .lookup_attributes(
72            aws_sdk_cloudtrail::types::LookupAttribute::builder()
73                .attribute_key(aws_sdk_cloudtrail::types::LookupAttributeKey::AccessKeyId)
74                .attribute_value(access_key_id)
75                .build()
76                .map_err(|e| AvError::Sts(format!("CloudTrail attribute build error: {}", e)))?,
77        )
78        .start_time(start)
79        .max_results(50)
80        .send()
81        .await
82        .map_err(|e| AvError::Sts(format!("CloudTrail LookupEvents error: {}", e)))?;
83
84    let events = result.events().iter().filter_map(|e| {
85        let event_name = e.event_name()?.to_string();
86
87        // Parse service from event source (e.g. "s3.amazonaws.com" -> "s3")
88        let service = e.event_source()
89            .map(|s| s.split('.').next().unwrap_or(s).to_string())
90            .unwrap_or_else(|| "unknown".to_string());
91
92        let read_only = e.read_only()
93            .and_then(|r| r.parse::<bool>().ok())
94            .unwrap_or(true);
95
96        // Parse error info from CloudTrail event JSON if available
97        let (error_code, error_message) = e.cloud_trail_event()
98            .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
99            .map(|v| {
100                let code = v.get("errorCode").and_then(|c| c.as_str()).map(String::from);
101                let msg = v.get("errorMessage").and_then(|m| m.as_str()).map(String::from);
102                (code, msg)
103            })
104            .unwrap_or((None, None));
105
106        let timestamp = e.event_time().map(|t| {
107            DateTime::from_timestamp(t.secs(), t.subsec_nanos())
108                .unwrap_or_else(Utc::now)
109        }).unwrap_or_else(Utc::now);
110
111        let source_ip = e.cloud_trail_event()
112            .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
113            .and_then(|v| v.get("sourceIPAddress").and_then(|s| s.as_str()).map(String::from));
114
115        let user_agent = e.cloud_trail_event()
116            .and_then(|json| serde_json::from_str::<serde_json::Value>(json).ok())
117            .and_then(|v| v.get("userAgent").and_then(|s| s.as_str()).map(String::from));
118
119        Some(ApiEvent {
120            timestamp,
121            event_name,
122            service,
123            source_ip,
124            user_agent,
125            error_code,
126            error_message,
127            read_only,
128        })
129    }).collect();
130
131    Ok(events)
132}
133
134/// Check if a session is still watchable (active and not expired).
135pub fn is_watchable(session: &Session) -> bool {
136    session.status == SessionStatus::Active && !session.is_expired()
137}
138
139/// Format an event as a colored terminal line.
140pub fn format_event(event: &ApiEvent) -> String {
141    let ts = event.timestamp.format("%H:%M:%S");
142    let rw = if event.read_only { "R" } else { "W" };
143    let status = if let Some(ref code) = event.error_code {
144        format!(" [ERR: {}]", code)
145    } else {
146        String::new()
147    };
148    format!(
149        "[{}] {} {:<12} {}{}",
150        ts, rw, event.service, event.event_name, status
151    )
152}
153
154/// Format a summary of the current watch state.
155pub fn format_summary(state: &WatchState, session: &Session) -> String {
156    let mut out = String::new();
157    out.push_str(&format!(
158        "Session {} | {} API calls | {} errors | {} services | {}s remaining\n",
159        &session.id[..8],
160        state.total_api_calls,
161        state.error_count,
162        state.services_used.len(),
163        session.remaining_seconds()
164    ));
165    if !state.services_used.is_empty() {
166        out.push_str(&format!("Services: {}\n", state.services_used.join(", ")));
167    }
168    out
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a read-only event stamped with the current time and fixed
    /// network metadata.
    fn make_event(name: &str, service: &str, error: Option<&str>) -> ApiEvent {
        ApiEvent {
            timestamp: Utc::now(),
            event_name: name.to_string(),
            service: service.to_string(),
            source_ip: Some("10.0.0.1".to_string()),
            user_agent: Some("aws-cli/2.0".to_string()),
            error_code: error.map(String::from),
            error_message: None,
            read_only: true,
        }
    }

    #[test]
    fn test_watch_state_merge() {
        let mut state = WatchState::default();
        let events = vec![
            make_event("GetObject", "s3", None),
            make_event("ListBuckets", "s3", None),
        ];
        state.merge(events);
        assert_eq!(state.total_api_calls, 2);
        assert_eq!(state.error_count, 0);
        assert_eq!(state.services_used, vec!["s3"]);
        assert!(state.last_poll.is_some());
    }

    #[test]
    fn test_watch_state_dedup() {
        let mut state = WatchState::default();
        let ts = Utc::now();
        let event = ApiEvent {
            timestamp: ts,
            event_name: "GetObject".to_string(),
            service: "s3".to_string(),
            source_ip: None,
            user_agent: None,
            error_code: None,
            error_message: None,
            read_only: true,
        };
        state.merge(vec![event.clone()]);
        state.merge(vec![event]);
        assert_eq!(state.total_api_calls, 1); // deduped
        assert_eq!(state.events.len(), 1);
    }

    #[test]
    fn test_watch_state_error_count() {
        let mut state = WatchState::default();
        let events = vec![
            make_event("GetObject", "s3", None),
            make_event("PutObject", "s3", Some("AccessDenied")),
        ];
        state.merge(events);
        assert_eq!(state.error_count, 1);
    }

    #[test]
    fn test_watch_state_multi_service() {
        let mut state = WatchState::default();
        let events = vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
            make_event("ListBuckets", "s3", None),
        ];
        state.merge(events);
        assert_eq!(state.services_used.len(), 2);
        assert!(state.services_used.contains(&"s3".to_string()));
        assert!(state.services_used.contains(&"lambda".to_string()));
    }

    #[test]
    fn test_format_event() {
        let event = make_event("GetObject", "s3", None);
        let line = format_event(&event);
        assert!(line.contains("s3"));
        assert!(line.contains("GetObject"));
        // The access column directly follows the "[HH:MM:SS]" prefix.
        assert!(line.contains("] R ")); // read-only
        assert!(!line.contains("[ERR"));
    }

    #[test]
    fn test_format_event_write() {
        let event = ApiEvent {
            read_only: false,
            ..make_event("PutObject", "s3", None)
        };
        let line = format_event(&event);
        assert!(line.contains("] W "));
    }

    #[test]
    fn test_format_event_with_error() {
        let event = make_event("PutObject", "s3", Some("AccessDenied"));
        let line = format_event(&event);
        assert!(line.contains("[ERR: AccessDenied]"));
    }

    #[test]
    fn test_format_summary() {
        let mut state = WatchState::default();
        state.merge(vec![
            make_event("GetObject", "s3", None),
            make_event("GetFunction", "lambda", None),
        ]);
        let session = crate::session::Session::new(
            std::time::Duration::from_secs(900),
            None,
            crate::policy::ScopedPolicy::from_allow_str("s3:GetObject").unwrap(),
            "arn:aws:iam::123456789012:role/Test".to_string(),
            vec!["test".to_string()],
        );
        let summary = format_summary(&state, &session);
        assert!(summary.contains("2 API calls"));
        assert!(summary.contains("0 errors"));
        assert!(summary.contains("2 services"));
        assert!(summary.contains("Services: s3, lambda"));
    }
}