1use std::sync::Arc;
11
12use axum::extract::ws::{Message, WebSocket};
13use axum::extract::{Query, State, WebSocketUpgrade};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tuitbot_core::auth::session;
19
20use crate::state::AppState;
21
22#[derive(Clone, Debug, Serialize, Deserialize)]
27pub struct AccountWsEvent {
28 pub account_id: String,
29 #[serde(flatten)]
30 pub event: WsEvent,
31}
32
33#[derive(Clone, Debug, Serialize, Deserialize)]
35#[serde(tag = "type")]
36pub enum WsEvent {
37 ActionPerformed {
39 action_type: String,
40 target: String,
41 content: String,
42 timestamp: String,
43 },
44 ApprovalQueued {
46 id: i64,
47 action_type: String,
48 content: String,
49 #[serde(default)]
50 media_paths: Vec<String>,
51 },
52 ApprovalUpdated {
54 id: i64,
55 status: String,
56 action_type: String,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 actor: Option<String>,
59 },
60 FollowerUpdate { count: i64, change: i64 },
62 RuntimeStatus {
64 running: bool,
65 active_loops: Vec<String>,
66 },
67 TweetDiscovered {
69 tweet_id: String,
70 author: String,
71 score: f64,
72 timestamp: String,
73 },
74 ActionSkipped {
76 action_type: String,
77 reason: String,
78 timestamp: String,
79 },
80 ContentScheduled {
82 id: i64,
83 content_type: String,
84 scheduled_for: Option<String>,
85 },
86 CircuitBreakerTripped {
88 state: String,
89 error_count: u32,
90 cooldown_remaining_seconds: u64,
91 timestamp: String,
92 },
93 Error { message: String },
95}
96
/// Query-string parameters accepted by `ws_handler`.
#[derive(Deserialize)]
pub struct WsQuery {
    /// Optional API token; `ws_handler` compares it against the server's
    /// configured `api_token` before falling back to cookie authentication.
    pub token: Option<String>,
}
103
104fn extract_session_cookie(headers: &HeaderMap) -> Option<String> {
106 headers
107 .get("cookie")
108 .and_then(|v| v.to_str().ok())
109 .and_then(|cookies| {
110 cookies.split(';').find_map(|c| {
111 let c = c.trim();
112 c.strip_prefix("tuitbot_session=").map(|v| v.to_string())
113 })
114 })
115}
116
117pub async fn ws_handler(
119 ws: WebSocketUpgrade,
120 State(state): State<Arc<AppState>>,
121 headers: HeaderMap,
122 Query(params): Query<WsQuery>,
123) -> Response {
124 if let Some(ref token) = params.token {
126 if token == &state.api_token {
127 return ws.on_upgrade(move |socket| handle_ws(socket, state));
128 }
129 }
130
131 if let Some(session_token) = extract_session_cookie(&headers) {
133 if let Ok(Some(_)) = session::validate_session(&state.db, &session_token).await {
134 return ws.on_upgrade(move |socket| handle_ws(socket, state));
135 }
136 }
137
138 (
139 StatusCode::UNAUTHORIZED,
140 axum::Json(json!({"error": "unauthorized"})),
141 )
142 .into_response()
143}
144
145async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
149 let mut rx = state.event_tx.subscribe();
150
151 loop {
152 match rx.recv().await {
153 Ok(event) => {
154 let json = match serde_json::to_string(&event) {
155 Ok(j) => j,
156 Err(e) => {
157 tracing::error!(error = %e, "Failed to serialize WsEvent");
158 continue;
159 }
160 };
161 if socket.send(Message::Text(json.into())).await.is_err() {
162 break;
164 }
165 }
166 Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
167 tracing::warn!(count, "WebSocket client lagged, events dropped");
168 let error_event = AccountWsEvent {
169 account_id: String::new(),
170 event: WsEvent::Error {
171 message: format!("{count} events dropped due to slow consumer"),
172 },
173 };
174 if let Ok(json) = serde_json::to_string(&error_event) {
175 if socket.send(Message::Text(json.into())).await.is_err() {
176 break;
177 }
178 }
179 }
180 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
181 break;
182 }
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Timestamp literal shared by the serialization fixtures.
    const TS: &str = "2026-03-15T12:00:00Z";

    /// Converts a serializable value into a `serde_json::Value` so individual
    /// keys can be asserted on.
    fn as_json<T: serde::Serialize>(value: &T) -> serde_json::Value {
        serde_json::to_value(value).unwrap()
    }

    /// Builds a header map holding one `cookie` header with the given value.
    fn cookie_headers(raw: &str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert("cookie", raw.parse().unwrap());
        map
    }

    // --- WsEvent serialization -------------------------------------------

    #[test]
    fn action_performed_serializes_with_type_tag() {
        let value = as_json(&WsEvent::ActionPerformed {
            action_type: "reply".into(),
            target: "@user".into(),
            content: "Hello!".into(),
            timestamp: TS.into(),
        });
        assert_eq!(value["type"], "ActionPerformed");
        assert_eq!(value["action_type"], "reply");
        assert_eq!(value["target"], "@user");
    }

    #[test]
    fn approval_queued_serializes() {
        let value = as_json(&WsEvent::ApprovalQueued {
            id: 42,
            action_type: "tweet".into(),
            content: "Draft tweet".into(),
            media_paths: vec!["img.png".into()],
        });
        assert_eq!(value["type"], "ApprovalQueued");
        assert_eq!(value["id"], 42);
        let media = value["media_paths"].as_array().unwrap();
        assert_eq!(media.len(), 1);
    }

    #[test]
    fn approval_updated_serializes_with_optional_actor() {
        // Same event shape twice; only the status and actor vary.
        let updated = |status: &str, actor: Option<&str>| {
            as_json(&WsEvent::ApprovalUpdated {
                id: 1,
                status: status.into(),
                action_type: "tweet".into(),
                actor: actor.map(String::from),
            })
        };

        let with_actor = updated("approved", Some("admin"));
        assert_eq!(with_actor["actor"], "admin");

        let without_actor = updated("rejected", None);
        assert!(
            without_actor.get("actor").is_none(),
            "actor should be skipped when None"
        );
    }

    #[test]
    fn follower_update_serializes() {
        let value = as_json(&WsEvent::FollowerUpdate {
            count: 1500,
            change: 25,
        });
        assert_eq!(value["type"], "FollowerUpdate");
        assert_eq!(value["count"], 1500);
        assert_eq!(value["change"], 25);
    }

    #[test]
    fn runtime_status_serializes() {
        let value = as_json(&WsEvent::RuntimeStatus {
            running: true,
            active_loops: vec!["mentions".into(), "discovery".into()],
        });
        assert_eq!(value["type"], "RuntimeStatus");
        assert_eq!(value["running"], true);
        let loops = value["active_loops"].as_array().unwrap();
        assert_eq!(loops.len(), 2);
    }

    #[test]
    fn tweet_discovered_serializes() {
        let value = as_json(&WsEvent::TweetDiscovered {
            tweet_id: "123456".into(),
            author: "user1".into(),
            score: 0.95,
            timestamp: TS.into(),
        });
        assert_eq!(value["type"], "TweetDiscovered");
        assert_eq!(value["score"], 0.95);
    }

    #[test]
    fn action_skipped_serializes() {
        let value = as_json(&WsEvent::ActionSkipped {
            action_type: "reply".into(),
            reason: "rate limited".into(),
            timestamp: TS.into(),
        });
        assert_eq!(value["type"], "ActionSkipped");
        assert_eq!(value["reason"], "rate limited");
    }

    #[test]
    fn content_scheduled_serializes() {
        let value = as_json(&WsEvent::ContentScheduled {
            id: 7,
            content_type: "tweet".into(),
            scheduled_for: Some("2026-03-16T09:00:00Z".into()),
        });
        assert_eq!(value["type"], "ContentScheduled");
        assert_eq!(value["id"], 7);
        assert!(value["scheduled_for"].is_string());
    }

    #[test]
    fn circuit_breaker_tripped_serializes() {
        let value = as_json(&WsEvent::CircuitBreakerTripped {
            state: "open".into(),
            error_count: 5,
            cooldown_remaining_seconds: 120,
            timestamp: TS.into(),
        });
        assert_eq!(value["type"], "CircuitBreakerTripped");
        assert_eq!(value["error_count"], 5);
        assert_eq!(value["cooldown_remaining_seconds"], 120);
    }

    #[test]
    fn error_event_serializes() {
        let value = as_json(&WsEvent::Error {
            message: "something broke".into(),
        });
        assert_eq!(value["type"], "Error");
        assert_eq!(value["message"], "something broke");
    }

    // --- AccountWsEvent wrapper ------------------------------------------

    #[test]
    fn account_ws_event_flattens_correctly() {
        let value = as_json(&AccountWsEvent {
            account_id: "acct-123".into(),
            event: WsEvent::FollowerUpdate {
                count: 100,
                change: 5,
            },
        });
        // The wrapper key and the event's own keys share one flat object.
        assert_eq!(value["account_id"], "acct-123");
        assert_eq!(value["type"], "FollowerUpdate");
        assert_eq!(value["count"], 100);
    }

    #[test]
    fn account_ws_event_roundtrip() {
        let encoded = serde_json::to_string(&AccountWsEvent {
            account_id: "acct-456".into(),
            event: WsEvent::Error {
                message: "test error".into(),
            },
        })
        .unwrap();

        let decoded: AccountWsEvent = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.account_id, "acct-456");
        if let WsEvent::Error { message } = decoded.event {
            assert_eq!(message, "test error");
        } else {
            panic!("expected Error variant");
        }
    }

    // --- Cookie extraction -----------------------------------------------

    #[test]
    fn extract_session_cookie_present() {
        let headers = cookie_headers("other=foo; tuitbot_session=abc123; another=bar");
        assert_eq!(extract_session_cookie(&headers).as_deref(), Some("abc123"));
    }

    #[test]
    fn extract_session_cookie_not_present() {
        let headers = cookie_headers("other=foo; another=bar");
        assert!(extract_session_cookie(&headers).is_none());
    }

    #[test]
    fn extract_session_cookie_no_cookie_header() {
        assert!(extract_session_cookie(&HeaderMap::new()).is_none());
    }

    // --- Deserialization -------------------------------------------------

    #[test]
    fn ws_event_deserializes_from_json() {
        let raw = r#"{"type":"Error","message":"test"}"#;
        let decoded: WsEvent = serde_json::from_str(raw).unwrap();
        if let WsEvent::Error { message } = decoded {
            assert_eq!(message, "test");
        } else {
            panic!("expected Error variant");
        }
    }

    #[test]
    fn all_event_variants_serialize_without_panic() {
        // One minimal instance of every variant.
        let variants = [
            WsEvent::ActionPerformed {
                action_type: "reply".into(),
                target: "t".into(),
                content: "c".into(),
                timestamp: "ts".into(),
            },
            WsEvent::ApprovalQueued {
                id: 1,
                action_type: "tweet".into(),
                content: "c".into(),
                media_paths: vec![],
            },
            WsEvent::ApprovalUpdated {
                id: 1,
                status: "s".into(),
                action_type: "a".into(),
                actor: None,
            },
            WsEvent::FollowerUpdate {
                count: 0,
                change: 0,
            },
            WsEvent::RuntimeStatus {
                running: false,
                active_loops: vec![],
            },
            WsEvent::TweetDiscovered {
                tweet_id: "t".into(),
                author: "a".into(),
                score: 0.0,
                timestamp: "ts".into(),
            },
            WsEvent::ActionSkipped {
                action_type: "a".into(),
                reason: "r".into(),
                timestamp: "ts".into(),
            },
            WsEvent::ContentScheduled {
                id: 1,
                content_type: "tweet".into(),
                scheduled_for: None,
            },
            WsEvent::CircuitBreakerTripped {
                state: "open".into(),
                error_count: 0,
                cooldown_remaining_seconds: 0,
                timestamp: "ts".into(),
            },
            WsEvent::Error {
                message: "err".into(),
            },
        ];

        for event in &variants {
            let encoded = serde_json::to_string(event).unwrap();
            assert!(!encoded.is_empty());
            // Every variant must also parse back from its own output.
            serde_json::from_str::<WsEvent>(&encoded).unwrap();
        }
    }
}