1use std::sync::Arc;
11
12use axum::extract::ws::{Message, WebSocket};
13use axum::extract::{Query, State, WebSocketUpgrade};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tuitbot_core::auth::session;
19
20use crate::state::AppState;
21
/// A real-time event broadcast to connected WebSocket clients.
///
/// Serialized as an internally tagged JSON object (`#[serde(tag = "type")]`):
/// the variant name appears under the `"type"` key next to the variant's
/// fields, e.g. `{"type":"FollowerUpdate","count":10,"change":1}`.
/// serde_json emits the remaining keys in field-declaration order, so
/// reordering or renaming fields changes the wire format.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WsEvent {
    /// An action was performed.
    /// NOTE(review): meaning and format of `target` / `timestamp` are inferred
    /// from the names only — confirm against the code that emits this event.
    ActionPerformed {
        action_type: String,
        target: String,
        content: String,
        timestamp: String,
    },
    /// An item was placed in the approval queue, identified by `id`.
    ApprovalQueued {
        id: i64,
        action_type: String,
        content: String,
        // When deserializing, a missing `media_paths` key yields an empty list.
        #[serde(default)]
        media_paths: Vec<String>,
    },
    /// An approval-queue item (by `id`) changed `status`.
    ApprovalUpdated {
        id: i64,
        status: String,
        action_type: String,
        // Omitted from the JSON entirely when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        actor: Option<String>,
    },
    /// Follower statistics update — presumably the current total (`count`)
    /// and the delta since the previous report (`change`); verify with producer.
    FollowerUpdate { count: i64, change: i64 },
    /// Runtime state: whether it is `running` and which loops are active.
    RuntimeStatus {
        running: bool,
        active_loops: Vec<String>,
    },
    /// A tweet was discovered; `score` semantics/range are not defined here.
    TweetDiscovered {
        tweet_id: String,
        author: String,
        score: f64,
        timestamp: String,
    },
    /// An action was skipped, with a human-readable `reason`.
    ActionSkipped {
        action_type: String,
        reason: String,
        timestamp: String,
    },
    /// Content was scheduled. `scheduled_for` serializes as `null` when `None`
    /// (there is no `skip_serializing_if` on it, unlike `ApprovalUpdated::actor`).
    ContentScheduled {
        id: i64,
        content_type: String,
        scheduled_for: Option<String>,
    },
    /// A circuit breaker tripped.
    /// NOTE(review): `cooldown_remaining_seconds` is assumed to be seconds per
    /// its name; `state` values are not enumerated in this file.
    CircuitBreakerTripped {
        state: String,
        error_count: u32,
        cooldown_remaining_seconds: u64,
        timestamp: String,
    },
    /// A free-form error notice. `handle_ws` sends this to a client that
    /// lagged behind the broadcast channel, reporting how many events it lost.
    Error { message: String },
}
85
/// Query-string parameters accepted by the WebSocket endpoint (`ws_handler`).
#[derive(Deserialize)]
pub struct WsQuery {
    /// Optional API token (`?token=...`). `ws_handler` upgrades the
    /// connection when it equals `state.api_token`; otherwise it falls back to
    /// the session cookie.
    /// NOTE(review): tokens in URLs can end up in proxy/access logs — prefer
    /// the cookie path where possible.
    pub token: Option<String>,
}
92
93fn extract_session_cookie(headers: &HeaderMap) -> Option<String> {
95 headers
96 .get("cookie")
97 .and_then(|v| v.to_str().ok())
98 .and_then(|cookies| {
99 cookies.split(';').find_map(|c| {
100 let c = c.trim();
101 c.strip_prefix("tuitbot_session=").map(|v| v.to_string())
102 })
103 })
104}
105
106pub async fn ws_handler(
108 ws: WebSocketUpgrade,
109 State(state): State<Arc<AppState>>,
110 headers: HeaderMap,
111 Query(params): Query<WsQuery>,
112) -> Response {
113 if let Some(ref token) = params.token {
115 if token == &state.api_token {
116 return ws.on_upgrade(move |socket| handle_ws(socket, state));
117 }
118 }
119
120 if let Some(session_token) = extract_session_cookie(&headers) {
122 if let Ok(Some(_)) = session::validate_session(&state.db, &session_token).await {
123 return ws.on_upgrade(move |socket| handle_ws(socket, state));
124 }
125 }
126
127 (
128 StatusCode::UNAUTHORIZED,
129 axum::Json(json!({"error": "unauthorized"})),
130 )
131 .into_response()
132}
133
134async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
138 let mut rx = state.event_tx.subscribe();
139
140 loop {
141 match rx.recv().await {
142 Ok(event) => {
143 let json = match serde_json::to_string(&event) {
144 Ok(j) => j,
145 Err(e) => {
146 tracing::error!(error = %e, "Failed to serialize WsEvent");
147 continue;
148 }
149 };
150 if socket.send(Message::Text(json.into())).await.is_err() {
151 break;
153 }
154 }
155 Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
156 tracing::warn!(count, "WebSocket client lagged, events dropped");
157 let error_event = WsEvent::Error {
158 message: format!("{count} events dropped due to slow consumer"),
159 };
160 if let Ok(json) = serde_json::to_string(&error_event) {
161 if socket.send(Message::Text(json.into())).await.is_err() {
162 break;
163 }
164 }
165 }
166 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
167 break;
168 }
169 }
170 }
171}