1use std::sync::Arc;
7
8use axum::extract::ws::{Message, WebSocket};
9use axum::extract::{Query, State, WebSocketUpgrade};
10use axum::http::StatusCode;
11use axum::response::{IntoResponse, Response};
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14
15use crate::state::AppState;
16
17#[derive(Clone, Debug, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum WsEvent {
21 ActionPerformed {
23 action_type: String,
24 target: String,
25 content: String,
26 timestamp: String,
27 },
28 ApprovalQueued {
30 id: i64,
31 action_type: String,
32 content: String,
33 },
34 ApprovalUpdated {
36 id: i64,
37 status: String,
38 action_type: String,
39 },
40 FollowerUpdate { count: i64, change: i64 },
42 RuntimeStatus {
44 running: bool,
45 active_loops: Vec<String>,
46 },
47 TweetDiscovered {
49 tweet_id: String,
50 author: String,
51 score: f64,
52 timestamp: String,
53 },
54 ActionSkipped {
56 action_type: String,
57 reason: String,
58 timestamp: String,
59 },
60 ContentScheduled {
62 id: i64,
63 content_type: String,
64 scheduled_for: Option<String>,
65 },
66 Error { message: String },
68}
69
70#[derive(Deserialize)]
72pub struct WsQuery {
73 pub token: String,
75}
76
77pub async fn ws_handler(
79 ws: WebSocketUpgrade,
80 State(state): State<Arc<AppState>>,
81 Query(params): Query<WsQuery>,
82) -> Response {
83 if params.token != state.api_token {
85 return (
86 StatusCode::UNAUTHORIZED,
87 axum::Json(json!({"error": "unauthorized"})),
88 )
89 .into_response();
90 }
91
92 ws.on_upgrade(move |socket| handle_ws(socket, state))
93}
94
95async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
99 let mut rx = state.event_tx.subscribe();
100
101 loop {
102 match rx.recv().await {
103 Ok(event) => {
104 let json = match serde_json::to_string(&event) {
105 Ok(j) => j,
106 Err(e) => {
107 tracing::error!(error = %e, "Failed to serialize WsEvent");
108 continue;
109 }
110 };
111 if socket.send(Message::Text(json.into())).await.is_err() {
112 break;
114 }
115 }
116 Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
117 tracing::warn!(count, "WebSocket client lagged, events dropped");
118 let error_event = WsEvent::Error {
119 message: format!("{count} events dropped due to slow consumer"),
120 };
121 if let Ok(json) = serde_json::to_string(&error_event) {
122 if socket.send(Message::Text(json.into())).await.is_err() {
123 break;
124 }
125 }
126 }
127 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
128 break;
129 }
130 }
131 }
132}