1use std::sync::Arc;
11
12use axum::extract::ws::{Message, WebSocket};
13use axum::extract::{Query, State, WebSocketUpgrade};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tuitbot_core::auth::session;
19
20use crate::state::AppState;
21
22#[derive(Clone, Debug, Serialize, Deserialize)]
27pub struct AccountWsEvent {
28 pub account_id: String,
29 #[serde(flatten)]
30 pub event: WsEvent,
31}
32
33#[derive(Clone, Debug, Serialize, Deserialize)]
35#[serde(tag = "type")]
36pub enum WsEvent {
37 ActionPerformed {
39 action_type: String,
40 target: String,
41 content: String,
42 timestamp: String,
43 },
44 ApprovalQueued {
46 id: i64,
47 action_type: String,
48 content: String,
49 #[serde(default)]
50 media_paths: Vec<String>,
51 },
52 ApprovalUpdated {
54 id: i64,
55 status: String,
56 action_type: String,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 actor: Option<String>,
59 },
60 FollowerUpdate { count: i64, change: i64 },
62 RuntimeStatus {
64 running: bool,
65 active_loops: Vec<String>,
66 },
67 TweetDiscovered {
69 tweet_id: String,
70 author: String,
71 score: f64,
72 timestamp: String,
73 },
74 ActionSkipped {
76 action_type: String,
77 reason: String,
78 timestamp: String,
79 },
80 ContentScheduled {
82 id: i64,
83 content_type: String,
84 scheduled_for: Option<String>,
85 },
86 CircuitBreakerTripped {
88 state: String,
89 error_count: u32,
90 cooldown_remaining_seconds: u64,
91 timestamp: String,
92 },
93 Error { message: String },
95}
96
97#[derive(Deserialize)]
99pub struct WsQuery {
100 pub token: Option<String>,
102}
103
104fn extract_session_cookie(headers: &HeaderMap) -> Option<String> {
106 headers
107 .get("cookie")
108 .and_then(|v| v.to_str().ok())
109 .and_then(|cookies| {
110 cookies.split(';').find_map(|c| {
111 let c = c.trim();
112 c.strip_prefix("tuitbot_session=").map(|v| v.to_string())
113 })
114 })
115}
116
117pub async fn ws_handler(
119 ws: WebSocketUpgrade,
120 State(state): State<Arc<AppState>>,
121 headers: HeaderMap,
122 Query(params): Query<WsQuery>,
123) -> Response {
124 if let Some(ref token) = params.token {
126 if token == &state.api_token {
127 return ws.on_upgrade(move |socket| handle_ws(socket, state));
128 }
129 }
130
131 if let Some(session_token) = extract_session_cookie(&headers) {
133 if let Ok(Some(_)) = session::validate_session(&state.db, &session_token).await {
134 return ws.on_upgrade(move |socket| handle_ws(socket, state));
135 }
136 }
137
138 (
139 StatusCode::UNAUTHORIZED,
140 axum::Json(json!({"error": "unauthorized"})),
141 )
142 .into_response()
143}
144
145async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
149 let mut rx = state.event_tx.subscribe();
150
151 loop {
152 match rx.recv().await {
153 Ok(event) => {
154 let json = match serde_json::to_string(&event) {
155 Ok(j) => j,
156 Err(e) => {
157 tracing::error!(error = %e, "Failed to serialize WsEvent");
158 continue;
159 }
160 };
161 if socket.send(Message::Text(json.into())).await.is_err() {
162 break;
164 }
165 }
166 Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
167 tracing::warn!(count, "WebSocket client lagged, events dropped");
168 let error_event = AccountWsEvent {
169 account_id: String::new(),
170 event: WsEvent::Error {
171 message: format!("{count} events dropped due to slow consumer"),
172 },
173 };
174 if let Ok(json) = serde_json::to_string(&error_event) {
175 if socket.send(Message::Text(json.into())).await.is_err() {
176 break;
177 }
178 }
179 }
180 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
181 break;
182 }
183 }
184 }
185}