1use std::sync::Arc;
7
8use axum::extract::ws::{Message, WebSocket};
9use axum::extract::{Query, State, WebSocketUpgrade};
10use axum::http::StatusCode;
11use axum::response::{IntoResponse, Response};
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14
15use crate::state::AppState;
16
/// Events pushed to connected WebSocket clients.
///
/// Each connection handled by `handle_ws` subscribes to `state.event_tx` and forwards
/// every event it receives as a JSON text frame.
///
/// Serialization uses serde's internally tagged representation (`tag = "type"`). Each
/// JSON object therefore carries a `"type"` field holding the variant name next to that
/// variant's fields, e.g. `{"type":"FollowerUpdate","count":..,"change":..}`.
/// Field order in the JSON follows declaration order, so do not reorder fields casually.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WsEvent {
    /// An action was performed.
    ///
    /// NOTE(review): `timestamp` is a preformatted string whose format is not visible
    /// here (presumably RFC 3339) — confirm against the producer.
    ActionPerformed {
        action_type: String,
        target: String,
        content: String,
        timestamp: String,
    },
    /// An action was placed in an approval queue; `id` presumably identifies the queue
    /// entry — verify against the producer.
    ApprovalQueued {
        id: i64,
        action_type: String,
        content: String,
    },
    /// Follower count update. NOTE(review): `change` looks like a delta relative to the
    /// previous count — confirm against the producer.
    FollowerUpdate { count: i64, change: i64 },
    /// Runtime state snapshot: whether it is running and the names of active loops.
    RuntimeStatus {
        running: bool,
        active_loops: Vec<String>,
    },
    /// A tweet was discovered. The range and meaning of `score` are defined by the
    /// producer (not visible here).
    TweetDiscovered {
        tweet_id: String,
        author: String,
        score: f64,
        timestamp: String,
    },
    /// An action was skipped, with a human-readable `reason`.
    ActionSkipped {
        action_type: String,
        reason: String,
        timestamp: String,
    },
    /// Error notice for the client.
    ///
    /// `handle_ws` emits this when a client's receiver lagged behind the broadcast channel
    /// and events were dropped. Other producers may exist elsewhere — check before
    /// assuming that is its only use.
    Error { message: String },
}
57
/// Query-string parameters accepted by `ws_handler`.
#[derive(Deserialize)]
pub struct WsQuery {
    /// API token; `ws_handler` checks it against `state.api_token` and rejects the
    /// upgrade with 401 on mismatch.
    ///
    /// NOTE(review): because it travels in the URL, proxies or access logs may record it;
    /// keep it out of any `Debug`/log output.
    pub token: String,
}
64
65pub async fn ws_handler(
67 ws: WebSocketUpgrade,
68 State(state): State<Arc<AppState>>,
69 Query(params): Query<WsQuery>,
70) -> Response {
71 if params.token != state.api_token {
73 return (
74 StatusCode::UNAUTHORIZED,
75 axum::Json(json!({"error": "unauthorized"})),
76 )
77 .into_response();
78 }
79
80 ws.on_upgrade(move |socket| handle_ws(socket, state))
81}
82
83async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
87 let mut rx = state.event_tx.subscribe();
88
89 loop {
90 match rx.recv().await {
91 Ok(event) => {
92 let json = match serde_json::to_string(&event) {
93 Ok(j) => j,
94 Err(e) => {
95 tracing::error!(error = %e, "Failed to serialize WsEvent");
96 continue;
97 }
98 };
99 if socket.send(Message::Text(json.into())).await.is_err() {
100 break;
102 }
103 }
104 Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
105 tracing::warn!(count, "WebSocket client lagged, events dropped");
106 let error_event = WsEvent::Error {
107 message: format!("{count} events dropped due to slow consumer"),
108 };
109 if let Ok(json) = serde_json::to_string(&error_event) {
110 if socket.send(Message::Text(json.into())).await.is_err() {
111 break;
112 }
113 }
114 }
115 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
116 break;
117 }
118 }
119 }
120}