Skip to main content

tuitbot_server/
ws.rs

1//! WebSocket hub for real-time event streaming.
2//!
3//! Provides a `/api/ws` endpoint that streams server events to dashboard clients
4//! via a `tokio::sync::broadcast` channel.
5
6use std::sync::Arc;
7
8use axum::extract::ws::{Message, WebSocket};
9use axum::extract::{Query, State, WebSocketUpgrade};
10use axum::http::StatusCode;
11use axum::response::{IntoResponse, Response};
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14
15use crate::state::AppState;
16
17/// Events pushed to WebSocket clients.
18#[derive(Clone, Debug, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum WsEvent {
21    /// An automation action was performed (reply, tweet, thread, etc.).
22    ActionPerformed {
23        action_type: String,
24        target: String,
25        content: String,
26        timestamp: String,
27    },
28    /// A new item was queued for approval.
29    ApprovalQueued {
30        id: i64,
31        action_type: String,
32        content: String,
33        #[serde(default)]
34        media_paths: Vec<String>,
35    },
36    /// An approval item's status was updated (approved, rejected, edited).
37    ApprovalUpdated {
38        id: i64,
39        status: String,
40        action_type: String,
41        #[serde(skip_serializing_if = "Option::is_none")]
42        actor: Option<String>,
43    },
44    /// Follower count changed.
45    FollowerUpdate { count: i64, change: i64 },
46    /// Automation runtime status changed.
47    RuntimeStatus {
48        running: bool,
49        active_loops: Vec<String>,
50    },
51    /// A tweet was discovered and scored by the discovery loop.
52    TweetDiscovered {
53        tweet_id: String,
54        author: String,
55        score: f64,
56        timestamp: String,
57    },
58    /// An action was skipped (rate limited, below threshold, safety filter).
59    ActionSkipped {
60        action_type: String,
61        reason: String,
62        timestamp: String,
63    },
64    /// A new content item was scheduled via the composer.
65    ContentScheduled {
66        id: i64,
67        content_type: String,
68        scheduled_for: Option<String>,
69    },
70    /// Circuit breaker state changed.
71    CircuitBreakerTripped {
72        state: String,
73        error_count: u32,
74        cooldown_remaining_seconds: u64,
75        timestamp: String,
76    },
77    /// An error occurred.
78    Error { message: String },
79}
80
81/// Query parameters for WebSocket authentication.
82#[derive(Deserialize)]
83pub struct WsQuery {
84    /// API token passed as a query parameter.
85    pub token: String,
86}
87
88/// `GET /api/ws?token=...` — WebSocket upgrade with token auth.
89pub async fn ws_handler(
90    ws: WebSocketUpgrade,
91    State(state): State<Arc<AppState>>,
92    Query(params): Query<WsQuery>,
93) -> Response {
94    // Authenticate via query parameter.
95    if params.token != state.api_token {
96        return (
97            StatusCode::UNAUTHORIZED,
98            axum::Json(json!({"error": "unauthorized"})),
99        )
100            .into_response();
101    }
102
103    ws.on_upgrade(move |socket| handle_ws(socket, state))
104}
105
106/// Handle a single WebSocket connection.
107///
108/// Subscribes to the broadcast channel and forwards events as JSON text frames.
109async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
110    let mut rx = state.event_tx.subscribe();
111
112    loop {
113        match rx.recv().await {
114            Ok(event) => {
115                let json = match serde_json::to_string(&event) {
116                    Ok(j) => j,
117                    Err(e) => {
118                        tracing::error!(error = %e, "Failed to serialize WsEvent");
119                        continue;
120                    }
121                };
122                if socket.send(Message::Text(json.into())).await.is_err() {
123                    // Client disconnected.
124                    break;
125                }
126            }
127            Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
128                tracing::warn!(count, "WebSocket client lagged, events dropped");
129                let error_event = WsEvent::Error {
130                    message: format!("{count} events dropped due to slow consumer"),
131                };
132                if let Ok(json) = serde_json::to_string(&error_event) {
133                    if socket.send(Message::Text(json.into())).await.is_err() {
134                        break;
135                    }
136                }
137            }
138            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
139                break;
140            }
141        }
142    }
143}