Skip to main content

tuitbot_server/
ws.rs

1//! WebSocket hub for real-time event streaming.
2//!
3//! Provides a `/api/ws` endpoint that streams server events to dashboard clients
4//! via a `tokio::sync::broadcast` channel.
5//!
6//! Supports two authentication methods:
7//! - Query parameter: `?token=<api_token>` (Tauri/API clients)
8//! - Session cookie: `tuitbot_session=<token>` (web/LAN clients)
9
10use std::sync::Arc;
11
12use axum::extract::ws::{Message, WebSocket};
13use axum::extract::{Query, State, WebSocketUpgrade};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tuitbot_core::auth::session;
19
20use crate::state::AppState;
21
/// Wrapper that tags every [`WsEvent`] with the originating account.
///
/// Serializes flat thanks to `#[serde(flatten)]`, so the JSON looks like:
/// `{ "account_id": "...", "type": "ApprovalQueued", ... }`
///
/// `handle_ws` also builds one of these with an empty `account_id` to report
/// events dropped because a client fell behind the broadcast channel.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AccountWsEvent {
    /// Account the event originated from (empty for connection-level notices).
    pub account_id: String,
    /// The event payload; its fields, including the `"type"` tag, are inlined
    /// into the same JSON object as `account_id`.
    #[serde(flatten)]
    pub event: WsEvent,
}
32
/// Events pushed to WebSocket clients.
///
/// Internally tagged: each variant serializes as a JSON object whose `"type"`
/// field holds the variant name, next to the variant's own fields, e.g.
/// `{ "type": "FollowerUpdate", "count": 120, "change": 3 }`.
///
/// NOTE(review): timestamp fields are plain `String`s; their format (presumably
/// RFC 3339) is decided by whoever emits the event and is not enforced here.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WsEvent {
    /// An automation action was performed (reply, tweet, thread, etc.).
    ActionPerformed {
        action_type: String,
        target: String,
        content: String,
        timestamp: String,
    },
    /// A new item was queued for approval.
    ApprovalQueued {
        id: i64,
        action_type: String,
        content: String,
        /// Defaults to an empty list when the field is absent on deserialization.
        #[serde(default)]
        media_paths: Vec<String>,
    },
    /// An approval item's status was updated (approved, rejected, edited).
    ApprovalUpdated {
        id: i64,
        status: String,
        action_type: String,
        /// Omitted from the JSON entirely when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        actor: Option<String>,
    },
    /// Follower count changed.
    ///
    /// NOTE(review): `change` looks like a signed delta from the previous
    /// count — confirm against the producer.
    FollowerUpdate { count: i64, change: i64 },
    /// Automation runtime status changed.
    RuntimeStatus {
        running: bool,
        active_loops: Vec<String>,
    },
    /// A tweet was discovered and scored by the discovery loop.
    TweetDiscovered {
        tweet_id: String,
        author: String,
        score: f64,
        timestamp: String,
    },
    /// An action was skipped (rate limited, below threshold, safety filter).
    ActionSkipped {
        action_type: String,
        reason: String,
        timestamp: String,
    },
    /// A new content item was scheduled via the composer.
    ContentScheduled {
        id: i64,
        content_type: String,
        /// Serialized as JSON `null` when `None` (not omitted, unlike `actor`
        /// in `ApprovalUpdated`).
        scheduled_for: Option<String>,
    },
    /// Circuit breaker state changed.
    ///
    /// NOTE(review): despite the variant name, this appears to report any
    /// state change, with the new state carried in `state` — confirm.
    CircuitBreakerTripped {
        state: String,
        error_count: u32,
        cooldown_remaining_seconds: u64,
        timestamp: String,
    },
    /// An error occurred.
    ///
    /// `handle_ws` also sends this to report events dropped for a slow client.
    Error { message: String },
}
96
/// Query parameters for WebSocket authentication.
///
/// Extracted from the `/api/ws` request's query string by `ws_handler`.
#[derive(Deserialize)]
pub struct WsQuery {
    /// API token passed as a query parameter (optional — cookie auth is fallback).
    ///
    /// A missing or non-matching value makes `ws_handler` fall through to
    /// session-cookie authentication.
    /// NOTE(review): tokens in URLs can end up in proxy/access logs.
    pub token: Option<String>,
}
103
104/// Extract the session cookie value from headers.
105fn extract_session_cookie(headers: &HeaderMap) -> Option<String> {
106    headers
107        .get("cookie")
108        .and_then(|v| v.to_str().ok())
109        .and_then(|cookies| {
110            cookies.split(';').find_map(|c| {
111                let c = c.trim();
112                c.strip_prefix("tuitbot_session=").map(|v| v.to_string())
113            })
114        })
115}
116
117/// `GET /api/ws` — WebSocket upgrade with token or cookie auth.
118pub async fn ws_handler(
119    ws: WebSocketUpgrade,
120    State(state): State<Arc<AppState>>,
121    headers: HeaderMap,
122    Query(params): Query<WsQuery>,
123) -> Response {
124    // Strategy 1: Bearer token via query parameter
125    if let Some(ref token) = params.token {
126        if token == &state.api_token {
127            return ws.on_upgrade(move |socket| handle_ws(socket, state));
128        }
129    }
130
131    // Strategy 2: Session cookie
132    if let Some(session_token) = extract_session_cookie(&headers) {
133        if let Ok(Some(_)) = session::validate_session(&state.db, &session_token).await {
134            return ws.on_upgrade(move |socket| handle_ws(socket, state));
135        }
136    }
137
138    (
139        StatusCode::UNAUTHORIZED,
140        axum::Json(json!({"error": "unauthorized"})),
141    )
142        .into_response()
143}
144
145/// Handle a single WebSocket connection.
146///
147/// Subscribes to the broadcast channel and forwards events as JSON text frames.
148async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
149    let mut rx = state.event_tx.subscribe();
150
151    loop {
152        match rx.recv().await {
153            Ok(event) => {
154                let json = match serde_json::to_string(&event) {
155                    Ok(j) => j,
156                    Err(e) => {
157                        tracing::error!(error = %e, "Failed to serialize WsEvent");
158                        continue;
159                    }
160                };
161                if socket.send(Message::Text(json.into())).await.is_err() {
162                    // Client disconnected.
163                    break;
164                }
165            }
166            Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
167                tracing::warn!(count, "WebSocket client lagged, events dropped");
168                let error_event = AccountWsEvent {
169                    account_id: String::new(),
170                    event: WsEvent::Error {
171                        message: format!("{count} events dropped due to slow consumer"),
172                    },
173                };
174                if let Ok(json) = serde_json::to_string(&error_event) {
175                    if socket.send(Message::Text(json.into())).await.is_err() {
176                        break;
177                    }
178                }
179            }
180            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
181                break;
182            }
183        }
184    }
185}