1use std::sync::Arc;
11
12use axum::extract::ws::{Message, WebSocket};
13use axum::extract::{Query, State, WebSocketUpgrade};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tuitbot_core::auth::session;
19
20use crate::state::AppState;
21
22#[derive(Clone, Debug, Serialize, Deserialize)]
27pub struct AccountWsEvent {
28 pub account_id: String,
29 #[serde(flatten)]
30 pub event: WsEvent,
31}
32
33#[derive(Clone, Debug, Serialize, Deserialize)]
35#[serde(tag = "type")]
36pub enum WsEvent {
37 ActionPerformed {
39 action_type: String,
40 target: String,
41 content: String,
42 timestamp: String,
43 },
44 ApprovalQueued {
46 id: i64,
47 action_type: String,
48 content: String,
49 #[serde(default)]
50 media_paths: Vec<String>,
51 },
52 ApprovalUpdated {
54 id: i64,
55 status: String,
56 action_type: String,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 actor: Option<String>,
59 },
60 FollowerUpdate { count: i64, change: i64 },
62 RuntimeStatus {
64 running: bool,
65 active_loops: Vec<String>,
66 },
67 TweetDiscovered {
69 tweet_id: String,
70 author: String,
71 score: f64,
72 timestamp: String,
73 },
74 ActionSkipped {
76 action_type: String,
77 reason: String,
78 timestamp: String,
79 },
80 ContentScheduled {
82 id: i64,
83 content_type: String,
84 scheduled_for: Option<String>,
85 },
86 CircuitBreakerTripped {
88 state: String,
89 error_count: u32,
90 cooldown_remaining_seconds: u64,
91 timestamp: String,
92 },
93 SelectionReceived { session_id: String },
95 Error { message: String },
97}
98
99#[derive(Deserialize)]
101pub struct WsQuery {
102 pub token: Option<String>,
104}
105
106fn extract_session_cookie(headers: &HeaderMap) -> Option<String> {
108 headers
109 .get("cookie")
110 .and_then(|v| v.to_str().ok())
111 .and_then(|cookies| {
112 cookies.split(';').find_map(|c| {
113 let c = c.trim();
114 c.strip_prefix("tuitbot_session=").map(|v| v.to_string())
115 })
116 })
117}
118
119pub async fn ws_handler(
121 ws: WebSocketUpgrade,
122 State(state): State<Arc<AppState>>,
123 headers: HeaderMap,
124 Query(params): Query<WsQuery>,
125) -> Response {
126 if let Some(ref token) = params.token {
128 if token == &state.api_token {
129 return ws.on_upgrade(move |socket| handle_ws(socket, state));
130 }
131 }
132
133 if let Some(session_token) = extract_session_cookie(&headers) {
135 if let Ok(Some(_)) = session::validate_session(&state.db, &session_token).await {
136 return ws.on_upgrade(move |socket| handle_ws(socket, state));
137 }
138 }
139
140 (
141 StatusCode::UNAUTHORIZED,
142 axum::Json(json!({"error": "unauthorized"})),
143 )
144 .into_response()
145}
146
147async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
151 let mut rx = state.event_tx.subscribe();
152
153 loop {
154 match rx.recv().await {
155 Ok(event) => {
156 let json = match serde_json::to_string(&event) {
157 Ok(j) => j,
158 Err(e) => {
159 tracing::error!(error = %e, "Failed to serialize WsEvent");
160 continue;
161 }
162 };
163 if socket.send(Message::Text(json.into())).await.is_err() {
164 break;
166 }
167 }
168 Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
169 tracing::warn!(count, "WebSocket client lagged, events dropped");
170 let error_event = AccountWsEvent {
171 account_id: String::new(),
172 event: WsEvent::Error {
173 message: format!("{count} events dropped due to slow consumer"),
174 },
175 };
176 if let Ok(json) = serde_json::to_string(&error_event) {
177 if socket.send(Message::Text(json.into())).await.is_err() {
178 break;
179 }
180 }
181 }
182 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
183 break;
184 }
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Timestamp literal shared by the serialization tests.
    const SAMPLE_TS: &str = "2026-03-15T12:00:00Z";

    /// Serializes `value` into a `serde_json::Value`, panicking on failure.
    fn to_json(value: &impl Serialize) -> serde_json::Value {
        serde_json::to_value(value).unwrap()
    }

    /// Builds a header map containing a single `cookie` header with `raw` as its value.
    fn headers_with_cookie(raw: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("cookie", raw.parse().unwrap());
        headers
    }

    // --- WsEvent serialization -------------------------------------------

    #[test]
    fn action_performed_serializes_with_type_tag() {
        let json = to_json(&WsEvent::ActionPerformed {
            action_type: "reply".to_string(),
            target: "@user".to_string(),
            content: "Hello!".to_string(),
            timestamp: SAMPLE_TS.to_string(),
        });
        assert_eq!(json["type"], "ActionPerformed");
        assert_eq!(json["action_type"], "reply");
        assert_eq!(json["target"], "@user");
    }

    #[test]
    fn approval_queued_serializes() {
        let json = to_json(&WsEvent::ApprovalQueued {
            id: 42,
            action_type: "tweet".to_string(),
            content: "Draft tweet".to_string(),
            media_paths: vec!["img.png".to_string()],
        });
        assert_eq!(json["type"], "ApprovalQueued");
        assert_eq!(json["id"], 42);
        assert_eq!(json["media_paths"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn approval_updated_serializes_with_optional_actor() {
        let with_actor = to_json(&WsEvent::ApprovalUpdated {
            id: 1,
            status: "approved".to_string(),
            action_type: "tweet".to_string(),
            actor: Some("admin".to_string()),
        });
        assert_eq!(with_actor["actor"], "admin");

        let without_actor = to_json(&WsEvent::ApprovalUpdated {
            id: 1,
            status: "rejected".to_string(),
            action_type: "tweet".to_string(),
            actor: None,
        });
        assert!(
            without_actor.get("actor").is_none(),
            "actor should be skipped when None"
        );
    }

    #[test]
    fn follower_update_serializes() {
        let json = to_json(&WsEvent::FollowerUpdate {
            count: 1500,
            change: 25,
        });
        assert_eq!(json["type"], "FollowerUpdate");
        assert_eq!(json["count"], 1500);
        assert_eq!(json["change"], 25);
    }

    #[test]
    fn runtime_status_serializes() {
        let json = to_json(&WsEvent::RuntimeStatus {
            running: true,
            active_loops: vec!["mentions".to_string(), "discovery".to_string()],
        });
        assert_eq!(json["type"], "RuntimeStatus");
        assert_eq!(json["running"], true);
        assert_eq!(json["active_loops"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn tweet_discovered_serializes() {
        let json = to_json(&WsEvent::TweetDiscovered {
            tweet_id: "123456".to_string(),
            author: "user1".to_string(),
            score: 0.95,
            timestamp: SAMPLE_TS.to_string(),
        });
        assert_eq!(json["type"], "TweetDiscovered");
        assert_eq!(json["score"], 0.95);
    }

    #[test]
    fn action_skipped_serializes() {
        let json = to_json(&WsEvent::ActionSkipped {
            action_type: "reply".to_string(),
            reason: "rate limited".to_string(),
            timestamp: SAMPLE_TS.to_string(),
        });
        assert_eq!(json["type"], "ActionSkipped");
        assert_eq!(json["reason"], "rate limited");
    }

    #[test]
    fn content_scheduled_serializes() {
        let json = to_json(&WsEvent::ContentScheduled {
            id: 7,
            content_type: "tweet".to_string(),
            scheduled_for: Some("2026-03-16T09:00:00Z".to_string()),
        });
        assert_eq!(json["type"], "ContentScheduled");
        assert_eq!(json["id"], 7);
        assert!(json["scheduled_for"].is_string());
    }

    #[test]
    fn circuit_breaker_tripped_serializes() {
        let json = to_json(&WsEvent::CircuitBreakerTripped {
            state: "open".to_string(),
            error_count: 5,
            cooldown_remaining_seconds: 120,
            timestamp: SAMPLE_TS.to_string(),
        });
        assert_eq!(json["type"], "CircuitBreakerTripped");
        assert_eq!(json["error_count"], 5);
        assert_eq!(json["cooldown_remaining_seconds"], 120);
    }

    #[test]
    fn error_event_serializes() {
        let json = to_json(&WsEvent::Error {
            message: "something broke".to_string(),
        });
        assert_eq!(json["type"], "Error");
        assert_eq!(json["message"], "something broke");
    }

    // --- AccountWsEvent wrapper ------------------------------------------

    #[test]
    fn account_ws_event_flattens_correctly() {
        let wrapped = AccountWsEvent {
            account_id: "acct-123".to_string(),
            event: WsEvent::FollowerUpdate {
                count: 100,
                change: 5,
            },
        };
        let json = to_json(&wrapped);
        assert_eq!(json["account_id"], "acct-123");
        assert_eq!(json["type"], "FollowerUpdate");
        assert_eq!(json["count"], 100);
    }

    #[test]
    fn account_ws_event_roundtrip() {
        let original = AccountWsEvent {
            account_id: "acct-456".to_string(),
            event: WsEvent::Error {
                message: "test error".to_string(),
            },
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: AccountWsEvent = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.account_id, "acct-456");
        let WsEvent::Error { message } = decoded.event else {
            panic!("expected Error variant");
        };
        assert_eq!(message, "test error");
    }

    // --- extract_session_cookie ------------------------------------------

    #[test]
    fn extract_session_cookie_present() {
        let headers = headers_with_cookie("other=foo; tuitbot_session=abc123; another=bar");
        assert_eq!(extract_session_cookie(&headers).as_deref(), Some("abc123"));
    }

    #[test]
    fn extract_session_cookie_not_present() {
        let headers = headers_with_cookie("other=foo; another=bar");
        assert!(extract_session_cookie(&headers).is_none());
    }

    #[test]
    fn extract_session_cookie_no_cookie_header() {
        assert!(extract_session_cookie(&HeaderMap::new()).is_none());
    }

    // --- WsEvent deserialization -----------------------------------------

    #[test]
    fn ws_event_deserializes_from_json() {
        let raw = r#"{"type":"Error","message":"test"}"#;
        let decoded: WsEvent = serde_json::from_str(raw).unwrap();
        let WsEvent::Error { message } = decoded else {
            panic!("expected Error variant");
        };
        assert_eq!(message, "test");
    }

    #[test]
    fn all_event_variants_serialize_without_panic() {
        let variants = [
            WsEvent::ActionPerformed {
                action_type: "reply".to_string(),
                target: "t".to_string(),
                content: "c".to_string(),
                timestamp: "ts".to_string(),
            },
            WsEvent::ApprovalQueued {
                id: 1,
                action_type: "tweet".to_string(),
                content: "c".to_string(),
                media_paths: Vec::new(),
            },
            WsEvent::ApprovalUpdated {
                id: 1,
                status: "s".to_string(),
                action_type: "a".to_string(),
                actor: None,
            },
            WsEvent::FollowerUpdate {
                count: 0,
                change: 0,
            },
            WsEvent::RuntimeStatus {
                running: false,
                active_loops: Vec::new(),
            },
            WsEvent::TweetDiscovered {
                tweet_id: "t".to_string(),
                author: "a".to_string(),
                score: 0.0,
                timestamp: "ts".to_string(),
            },
            WsEvent::ActionSkipped {
                action_type: "a".to_string(),
                reason: "r".to_string(),
                timestamp: "ts".to_string(),
            },
            WsEvent::ContentScheduled {
                id: 1,
                content_type: "tweet".to_string(),
                scheduled_for: None,
            },
            WsEvent::CircuitBreakerTripped {
                state: "open".to_string(),
                error_count: 0,
                cooldown_remaining_seconds: 0,
                timestamp: "ts".to_string(),
            },
            WsEvent::SelectionReceived {
                session_id: "sess-1".to_string(),
            },
            WsEvent::Error {
                message: "err".to_string(),
            },
        ];

        for variant in &variants {
            let encoded = serde_json::to_string(variant).unwrap();
            assert!(!encoded.is_empty());
            // Every variant must also parse back from its own encoding.
            let _decoded: WsEvent = serde_json::from_str(&encoded).unwrap();
        }
    }
}