// tuitbot_server/ws.rs

//! WebSocket hub for real-time event streaming.
//!
//! Provides a `/api/ws` endpoint that streams server events to dashboard clients
//! via a `tokio::sync::broadcast` channel.

6use std::sync::Arc;
7
8use axum::extract::ws::{Message, WebSocket};
9use axum::extract::{Query, State, WebSocketUpgrade};
10use axum::http::StatusCode;
11use axum::response::{IntoResponse, Response};
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14
15use crate::state::AppState;
16
17/// Events pushed to WebSocket clients.
18#[derive(Clone, Debug, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum WsEvent {
21    /// An automation action was performed (reply, tweet, thread, etc.).
22    ActionPerformed {
23        action_type: String,
24        target: String,
25        content: String,
26        timestamp: String,
27    },
28    /// A new item was queued for approval.
29    ApprovalQueued {
30        id: i64,
31        action_type: String,
32        content: String,
33        #[serde(default)]
34        media_paths: Vec<String>,
35    },
36    /// An approval item's status was updated (approved, rejected, edited).
37    ApprovalUpdated {
38        id: i64,
39        status: String,
40        action_type: String,
41    },
42    /// Follower count changed.
43    FollowerUpdate { count: i64, change: i64 },
44    /// Automation runtime status changed.
45    RuntimeStatus {
46        running: bool,
47        active_loops: Vec<String>,
48    },
49    /// A tweet was discovered and scored by the discovery loop.
50    TweetDiscovered {
51        tweet_id: String,
52        author: String,
53        score: f64,
54        timestamp: String,
55    },
56    /// An action was skipped (rate limited, below threshold, safety filter).
57    ActionSkipped {
58        action_type: String,
59        reason: String,
60        timestamp: String,
61    },
62    /// A new content item was scheduled via the composer.
63    ContentScheduled {
64        id: i64,
65        content_type: String,
66        scheduled_for: Option<String>,
67    },
68    /// An error occurred.
69    Error { message: String },
70}
71
72/// Query parameters for WebSocket authentication.
73#[derive(Deserialize)]
74pub struct WsQuery {
75    /// API token passed as a query parameter.
76    pub token: String,
77}
78
79/// `GET /api/ws?token=...` — WebSocket upgrade with token auth.
80pub async fn ws_handler(
81    ws: WebSocketUpgrade,
82    State(state): State<Arc<AppState>>,
83    Query(params): Query<WsQuery>,
84) -> Response {
85    // Authenticate via query parameter.
86    if params.token != state.api_token {
87        return (
88            StatusCode::UNAUTHORIZED,
89            axum::Json(json!({"error": "unauthorized"})),
90        )
91            .into_response();
92    }
93
94    ws.on_upgrade(move |socket| handle_ws(socket, state))
95}
96
97/// Handle a single WebSocket connection.
98///
99/// Subscribes to the broadcast channel and forwards events as JSON text frames.
100async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
101    let mut rx = state.event_tx.subscribe();
102
103    loop {
104        match rx.recv().await {
105            Ok(event) => {
106                let json = match serde_json::to_string(&event) {
107                    Ok(j) => j,
108                    Err(e) => {
109                        tracing::error!(error = %e, "Failed to serialize WsEvent");
110                        continue;
111                    }
112                };
113                if socket.send(Message::Text(json.into())).await.is_err() {
114                    // Client disconnected.
115                    break;
116                }
117            }
118            Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
119                tracing::warn!(count, "WebSocket client lagged, events dropped");
120                let error_event = WsEvent::Error {
121                    message: format!("{count} events dropped due to slow consumer"),
122                };
123                if let Ok(json) = serde_json::to_string(&error_event) {
124                    if socket.send(Message::Text(json.into())).await.is_err() {
125                        break;
126                    }
127                }
128            }
129            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
130                break;
131            }
132        }
133    }
134}