1use std::sync::Arc;
7
8use axum::extract::ws::{Message, WebSocket};
9use axum::extract::{Query, State, WebSocketUpgrade};
10use axum::http::StatusCode;
11use axum::response::{IntoResponse, Response};
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14
15use crate::state::AppState;
16
/// Event serialized to JSON and written to WebSocket clients by `handle_ws`.
///
/// `#[serde(tag = "type")]` makes this enum internally tagged: every event is
/// one JSON object whose `"type"` key holds the variant name, followed by the
/// variant's fields in declaration order, e.g.
/// `{"type":"FollowerUpdate","count":10,"change":1}`.
/// Variant names, field names and field order are therefore part of the wire
/// format clients see; treat changes to them as protocol changes.
///
/// NOTE(review): the `timestamp` / `scheduled_for` strings are produced
/// elsewhere and their format (RFC 3339?) is not visible here — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WsEvent {
    /// Presumably emitted after an action has been carried out — confirm
    /// with the producer.
    ActionPerformed {
        action_type: String,
        target: String,
        content: String,
        timestamp: String,
    },
    /// Presumably emitted when an item enters the approval queue; `id` is
    /// presumably the queue entry's id — confirm.
    ApprovalQueued {
        id: i64,
        action_type: String,
        content: String,
        // When deserializing, a missing `media_paths` key becomes an empty Vec.
        #[serde(default)]
        media_paths: Vec<String>,
    },
    /// Presumably emitted when an approval item's `status` changes.
    ApprovalUpdated {
        id: i64,
        status: String,
        action_type: String,
        // Omitted from the JSON entirely when `None` (not sent as `null`);
        // a missing key deserializes back to `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        actor: Option<String>,
    },
    /// Follower counts; `count` is presumably the new total and `change` the
    /// delta since the previous update — confirm with the producer.
    FollowerUpdate { count: i64, change: i64 },
    /// Runtime state: a running flag plus the names of active loops
    /// (semantics set by the producer).
    RuntimeStatus {
        running: bool,
        active_loops: Vec<String>,
    },
    /// Presumably emitted when a tweet is found and scored; the scale of
    /// `score` is not visible here.
    TweetDiscovered {
        tweet_id: String,
        author: String,
        score: f64,
        timestamp: String,
    },
    /// An action that was not performed, with a human-readable `reason`.
    ActionSkipped {
        action_type: String,
        reason: String,
        timestamp: String,
    },
    /// Content scheduling notice. `scheduled_for` has no skip attribute, so
    /// `None` is serialized as JSON `null`.
    ContentScheduled {
        id: i64,
        content_type: String,
        scheduled_for: Option<String>,
    },
    /// Circuit-breaker notice carrying its `state`, error count and the
    /// remaining cooldown in seconds (per the field name).
    CircuitBreakerTripped {
        state: String,
        error_count: u32,
        cooldown_remaining_seconds: u64,
        timestamp: String,
    },
    /// Error notice for the client. `handle_ws` sends this when the client's
    /// broadcast receiver lagged and events were dropped.
    Error { message: String },
}
80
/// Query-string parameters accepted by `ws_handler` (`?token=...`).
///
/// Because `token` is a required `String`, a request with no `token`
/// parameter fails in axum's `Query` extractor (a `QueryRejection`, HTTP 400)
/// before `ws_handler` runs; only a present-but-wrong token reaches the 401
/// branch.
///
/// NOTE(review): tokens passed in the query string can end up in proxy and
/// access logs — worth confirming that is acceptable for this deployment.
#[derive(Deserialize)]
pub struct WsQuery {
    /// Token compared against `AppState::api_token` before upgrading.
    pub token: String,
}
87
88pub async fn ws_handler(
90 ws: WebSocketUpgrade,
91 State(state): State<Arc<AppState>>,
92 Query(params): Query<WsQuery>,
93) -> Response {
94 if params.token != state.api_token {
96 return (
97 StatusCode::UNAUTHORIZED,
98 axum::Json(json!({"error": "unauthorized"})),
99 )
100 .into_response();
101 }
102
103 ws.on_upgrade(move |socket| handle_ws(socket, state))
104}
105
106async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
110 let mut rx = state.event_tx.subscribe();
111
112 loop {
113 match rx.recv().await {
114 Ok(event) => {
115 let json = match serde_json::to_string(&event) {
116 Ok(j) => j,
117 Err(e) => {
118 tracing::error!(error = %e, "Failed to serialize WsEvent");
119 continue;
120 }
121 };
122 if socket.send(Message::Text(json.into())).await.is_err() {
123 break;
125 }
126 }
127 Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
128 tracing::warn!(count, "WebSocket client lagged, events dropped");
129 let error_event = WsEvent::Error {
130 message: format!("{count} events dropped due to slow consumer"),
131 };
132 if let Ok(json) = serde_json::to_string(&error_event) {
133 if socket.send(Message::Text(json.into())).await.is_err() {
134 break;
135 }
136 }
137 }
138 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
139 break;
140 }
141 }
142 }
143}