Skip to main content

tuitbot_server/
ws.rs

//! WebSocket hub for real-time event streaming.
//!
//! Provides a `/api/ws` endpoint that streams server events to dashboard clients
//! via a `tokio::sync::broadcast` channel.
//!
//! Supports two authentication methods:
//! - Query parameter: `?token=<api_token>` (Tauri/API clients)
//! - Session cookie: `tuitbot_session=<token>` (web/LAN clients)
9
10use std::sync::Arc;
11
12use axum::extract::ws::{Message, WebSocket};
13use axum::extract::{Query, State, WebSocketUpgrade};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tuitbot_core::auth::session;
19
20use crate::state::AppState;
21
/// Events pushed to WebSocket clients.
///
/// Serialized with serde's internally tagged representation (`tag = "type"`):
/// each JSON object carries a `"type"` field holding the variant name alongside
/// the variant's own fields, e.g. `{"type":"FollowerUpdate","count":10,"change":1}`.
/// Fields are emitted in declaration order, so reordering them changes the wire output.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WsEvent {
    /// An automation action was performed (reply, tweet, thread, etc.).
    ActionPerformed {
        action_type: String,
        target: String,
        content: String,
        // Timestamp string; format is chosen by the emitter — TODO confirm (RFC 3339?).
        timestamp: String,
    },
    /// A new item was queued for approval.
    ApprovalQueued {
        id: i64,
        action_type: String,
        content: String,
        // `serde(default)`: a payload without this field deserializes to an empty list.
        #[serde(default)]
        media_paths: Vec<String>,
    },
    /// An approval item's status was updated (approved, rejected, edited).
    ApprovalUpdated {
        id: i64,
        status: String,
        action_type: String,
        // Omitted from the JSON entirely when `None` (not sent as `null`).
        #[serde(skip_serializing_if = "Option::is_none")]
        actor: Option<String>,
    },
    /// Follower count changed.
    ///
    /// `change` is presumably the delta since the previous update — confirm with the emitter.
    FollowerUpdate { count: i64, change: i64 },
    /// Automation runtime status changed.
    RuntimeStatus {
        running: bool,
        active_loops: Vec<String>,
    },
    /// A tweet was discovered and scored by the discovery loop.
    TweetDiscovered {
        tweet_id: String,
        author: String,
        score: f64,
        timestamp: String,
    },
    /// An action was skipped (rate limited, below threshold, safety filter).
    ActionSkipped {
        action_type: String,
        reason: String,
        timestamp: String,
    },
    /// A new content item was scheduled via the composer.
    ContentScheduled {
        id: i64,
        content_type: String,
        // No skip attribute: `None` is serialized as JSON `null`.
        scheduled_for: Option<String>,
    },
    /// Circuit breaker state changed.
    ///
    /// NOTE(review): despite the variant name, this carries a `state` string, so it
    /// may be emitted for any breaker transition, not only trips — confirm with emitters.
    CircuitBreakerTripped {
        state: String,
        error_count: u32,
        cooldown_remaining_seconds: u64,
        timestamp: String,
    },
    /// An error occurred.
    ///
    /// Also sent by the connection handler when a slow client lags the broadcast
    /// channel and events are dropped.
    Error { message: String },
}
85
/// Query parameters accepted by `GET /api/ws` for authentication.
#[derive(Deserialize)]
pub struct WsQuery {
    /// API token passed as `?token=<api_token>` (optional — cookie auth is fallback).
    ///
    /// `None` when the parameter is absent; the handler then tries the
    /// `tuitbot_session` cookie instead.
    pub token: Option<String>,
}
92
93/// Extract the session cookie value from headers.
94fn extract_session_cookie(headers: &HeaderMap) -> Option<String> {
95    headers
96        .get("cookie")
97        .and_then(|v| v.to_str().ok())
98        .and_then(|cookies| {
99            cookies.split(';').find_map(|c| {
100                let c = c.trim();
101                c.strip_prefix("tuitbot_session=").map(|v| v.to_string())
102            })
103        })
104}
105
106/// `GET /api/ws` — WebSocket upgrade with token or cookie auth.
107pub async fn ws_handler(
108    ws: WebSocketUpgrade,
109    State(state): State<Arc<AppState>>,
110    headers: HeaderMap,
111    Query(params): Query<WsQuery>,
112) -> Response {
113    // Strategy 1: Bearer token via query parameter
114    if let Some(ref token) = params.token {
115        if token == &state.api_token {
116            return ws.on_upgrade(move |socket| handle_ws(socket, state));
117        }
118    }
119
120    // Strategy 2: Session cookie
121    if let Some(session_token) = extract_session_cookie(&headers) {
122        if let Ok(Some(_)) = session::validate_session(&state.db, &session_token).await {
123            return ws.on_upgrade(move |socket| handle_ws(socket, state));
124        }
125    }
126
127    (
128        StatusCode::UNAUTHORIZED,
129        axum::Json(json!({"error": "unauthorized"})),
130    )
131        .into_response()
132}
133
134/// Handle a single WebSocket connection.
135///
136/// Subscribes to the broadcast channel and forwards events as JSON text frames.
137async fn handle_ws(mut socket: WebSocket, state: Arc<AppState>) {
138    let mut rx = state.event_tx.subscribe();
139
140    loop {
141        match rx.recv().await {
142            Ok(event) => {
143                let json = match serde_json::to_string(&event) {
144                    Ok(j) => j,
145                    Err(e) => {
146                        tracing::error!(error = %e, "Failed to serialize WsEvent");
147                        continue;
148                    }
149                };
150                if socket.send(Message::Text(json.into())).await.is_err() {
151                    // Client disconnected.
152                    break;
153                }
154            }
155            Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
156                tracing::warn!(count, "WebSocket client lagged, events dropped");
157                let error_event = WsEvent::Error {
158                    message: format!("{count} events dropped due to slow consumer"),
159                };
160                if let Ok(json) = serde_json::to_string(&error_event) {
161                    if socket.send(Message::Text(json.into())).await.is_err() {
162                        break;
163                    }
164                }
165            }
166            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
167                break;
168            }
169        }
170    }
171}