Skip to main content

studio_worker/ws/
client.rs

//! `tokio-tungstenite`-backed client for the studio WS worker channel.
//!
//! Responsibilities:
//!  - coerce `http(s)://` API URLs to `ws(s)://`, add the `/graphics/api`
//!    prefix when it is missing, and append `/workers/<worker_id>/connect`
//!  - attach `Authorization: Bearer <token>` and the
//!    `studio-worker-v1` sub-protocol header to the upgrade
//!  - map 401 upgrade responses + 4001 close codes to a typed
//!    `WsClientError::AuthFailed` so the runtime can surface a
//!    friendly hint
//!  - serialise `WorkerInbound` to JSON text frames and parse
//!    `WorkerOutbound` from incoming frames
//!  - best-effort clean shutdown via `WsClient::close()` / `WsSender::close()`
13use std::convert::TryFrom;
14use std::time::Duration;
15
16use std::sync::Arc;
17
18use futures_util::stream::{SplitSink, SplitStream};
19use futures_util::{SinkExt, StreamExt};
20use tokio::net::TcpStream;
21use tokio::sync::Mutex;
22use tokio_tungstenite::tungstenite::client::IntoClientRequest;
23use tokio_tungstenite::tungstenite::http::HeaderValue;
24use tokio_tungstenite::tungstenite::http::StatusCode;
25use tokio_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFrame};
26use tokio_tungstenite::tungstenite::{Error as TError, Message};
27use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
28use url::Url;
29
30use crate::ws::types::{WorkerInbound, WorkerOutbound};
31
/// WebSocket sub-protocol the worker requests via the
/// `Sec-WebSocket-Protocol` header of the upgrade request.
pub const SUBPROTOCOL: &str = "studio-worker-v1";
/// Mirrors the same prefix the HTTP `ApiClient` mounts under.  Stays
/// single-sourced with the API's Hono `basePath('/api')` + outer
/// `/graphics` mount.
const API_PREFIX: &str = "/graphics/api";
37
38/// Result wrapper for WS-client operations.
39pub type WsResult<T> = Result<T, WsClientError>;
40
41/// Errors surfaced by the client.  All variants carry just enough
42/// context to log a useful warning + to drive the reconnect policy.
43#[derive(Debug, thiserror::Error)]
44pub enum WsClientError {
45    /// Upgrade returned 401 or the server closed with 4001.
46    #[error("auth failed: {reason}")]
47    AuthFailed { reason: String },
48
49    /// Server closed for a reason other than auth failure.  The runtime
50    /// treats this as a transient drop and tries to reconnect.
51    #[error("connection closed by server")]
52    ConnectionClosed,
53
54    /// Anything else (DNS, TLS, timeout).
55    #[error("ws transport error: {0}")]
56    Transport(String),
57
58    /// Frame couldn't be parsed as JSON `WorkerOutbound`.
59    #[error("protocol error: {0}")]
60    Protocol(String),
61}
62
63impl From<TError> for WsClientError {
64    fn from(value: TError) -> Self {
65        match value {
66            TError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => {
67                WsClientError::AuthFailed {
68                    reason: "401 on websocket upgrade".to_string(),
69                }
70            }
71            TError::ConnectionClosed | TError::AlreadyClosed => WsClientError::ConnectionClosed,
72            other => WsClientError::Transport(other.to_string()),
73        }
74    }
75}
76
77/// Coerce an `http://...api` base URL to the WS URL the server expects.
78fn build_connect_url(base_url: &str, worker_id: &str) -> WsResult<Url> {
79    let mut url = Url::parse(base_url)
80        .map_err(|e| WsClientError::Transport(format!("invalid base url: {e}")))?;
81    let new_scheme = match url.scheme() {
82        "http" => Some("ws"),
83        "https" => Some("wss"),
84        "ws" | "wss" => None, // already in WS form
85        other => {
86            return Err(WsClientError::Transport(format!(
87                "unsupported scheme: {other}"
88            )))
89        }
90    };
91    if let Some(scheme) = new_scheme {
92        url.set_scheme(scheme)
93            .map_err(|_| WsClientError::Transport("set_scheme failed".to_string()))?;
94    }
95    let trimmed_path = url.path().trim_end_matches('/');
96    // Append the studio's `/graphics/api` prefix unless the caller has
97    // already baked it into `base_url` (matches what `ApiClient::url`
98    // does on the HTTP side).
99    let prefixed = if trimmed_path.ends_with(API_PREFIX) {
100        trimmed_path.to_string()
101    } else {
102        format!("{trimmed_path}{API_PREFIX}")
103    };
104    let new_path = format!("{prefixed}/workers/{worker_id}/connect");
105    url.set_path(&new_path);
106    Ok(url)
107}
108
109/// Establish the WebSocket session.  Sends the upgrade with the bearer
110/// token + sub-protocol header and returns a ready-to-use client.
111pub async fn connect(base_url: &str, worker_id: &str, auth_token: &str) -> WsResult<WsClient> {
112    let url = build_connect_url(base_url, worker_id)?;
113    let mut request = url
114        .as_str()
115        .into_client_request()
116        .map_err(WsClientError::from)?;
117    let headers = request.headers_mut();
118    headers.insert(
119        "Authorization",
120        HeaderValue::try_from(format!("Bearer {auth_token}"))
121            .map_err(|e| WsClientError::Transport(format!("invalid auth header: {e}")))?,
122    );
123    headers.insert(
124        "Sec-WebSocket-Protocol",
125        HeaderValue::from_static(SUBPROTOCOL),
126    );
127
128    let (stream, _response) = tokio_tungstenite::connect_async(request).await?;
129    let (sink, source) = stream.split();
130    Ok(WsClient {
131        sink,
132        source,
133        closed: false,
134    })
135}
136
137type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
138type WsSource = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
139
140/// Active worker-side WS session.  Cheap to construct, expensive to
141/// drop (closes the socket gracefully).
142#[allow(missing_debug_implementations)]
143pub struct WsClient {
144    sink: WsSink,
145    source: WsSource,
146    closed: bool,
147}
148
149impl WsClient {
150    /// Split the client into a cheap-to-clone `WsSender` and a
151    /// single-owner `WsReceiver`.  Used by the runtime so heartbeat,
152    /// log-shipper, and engine-dispatch tasks can all push frames
153    /// concurrently while a dedicated task drains the receive side.
154    pub fn split(self) -> (WsSender, WsReceiver) {
155        let sink = Arc::new(Mutex::new(self.sink));
156        (
157            WsSender { sink },
158            WsReceiver {
159                source: self.source,
160                closed: false,
161            },
162        )
163    }
164}
165
166/// Cheap-to-clone send half.  All senders share one `Mutex` over the
167/// underlying sink so writes from heartbeat / log-shipper / engine
168/// dispatch tasks are serialised correctly.
169#[derive(Clone)]
170#[allow(missing_debug_implementations)]
171pub struct WsSender {
172    sink: Arc<Mutex<WsSink>>,
173}
174
175impl WsSender {
176    pub async fn send(&self, frame: &WorkerInbound) -> WsResult<()> {
177        let text =
178            serde_json::to_string(frame).map_err(|e| WsClientError::Protocol(e.to_string()))?;
179        let mut guard = self.sink.lock().await;
180        guard
181            .send(Message::Text(text.into()))
182            .await
183            .map_err(WsClientError::from)
184    }
185
186    pub async fn close(&self, code: u16, reason: &str) -> WsResult<()> {
187        let frame = CloseFrame {
188            code: CloseCode::from(code),
189            reason: reason.to_owned().into(),
190        };
191        let mut guard = self.sink.lock().await;
192        let _ = tokio::time::timeout(
193            Duration::from_secs(5),
194            guard.send(Message::Close(Some(frame))),
195        )
196        .await;
197        Ok(())
198    }
199}
200
201/// Single-owner receive half.  Owned by the session's reader task.
202#[allow(missing_debug_implementations)]
203pub struct WsReceiver {
204    source: WsSource,
205    closed: bool,
206}
207
208impl WsReceiver {
209    /// Read the next outbound frame.  Same semantics as
210    /// `WsClient::recv` — silent close → `Ok(None)`, close frame with
211    /// 4001 → `AuthFailed`, other closes → `ConnectionClosed`.
212    pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
213        if self.closed {
214            return Ok(None);
215        }
216        while let Some(item) = self.source.next().await {
217            match item {
218                Ok(Message::Text(text)) => {
219                    let frame: WorkerOutbound = serde_json::from_str(&text)
220                        .map_err(|e| WsClientError::Protocol(e.to_string()))?;
221                    return Ok(Some(frame));
222                }
223                Ok(Message::Binary(_)) => {
224                    return Err(WsClientError::Protocol(
225                        "unexpected binary frame".to_string(),
226                    ));
227                }
228                Ok(Message::Close(frame)) => {
229                    self.closed = true;
230                    return Err(close_frame_to_error(frame));
231                }
232                Ok(_) => continue,
233                Err(e) => return Err(WsClientError::from(e)),
234            }
235        }
236        self.closed = true;
237        Ok(None)
238    }
239}
240
241impl std::fmt::Debug for WsClient {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        f.debug_struct("WsClient")
244            .field("closed", &self.closed)
245            .finish()
246    }
247}
248
249impl WsClient {
250    /// Send a typed inbound frame as a JSON text frame.
251    pub async fn send(&mut self, frame: &WorkerInbound) -> WsResult<()> {
252        let text =
253            serde_json::to_string(frame).map_err(|e| WsClientError::Protocol(e.to_string()))?;
254        self.sink
255            .send(Message::Text(text.into()))
256            .await
257            .map_err(WsClientError::from)
258    }
259
260    /// Receive the next typed outbound frame.  Returns `Ok(None)` on
261    /// a clean close (no error frame), `Err` on auth or transport
262    /// failures, or `Ok(Some(frame))` for normal traffic.  Pings and
263    /// other control frames are swallowed silently.
264    pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
265        if self.closed {
266            return Ok(None);
267        }
268        while let Some(item) = self.source.next().await {
269            match item {
270                Ok(Message::Text(text)) => {
271                    let frame: WorkerOutbound = serde_json::from_str(&text)
272                        .map_err(|e| WsClientError::Protocol(e.to_string()))?;
273                    return Ok(Some(frame));
274                }
275                Ok(Message::Binary(_)) => {
276                    return Err(WsClientError::Protocol(
277                        "unexpected binary frame".to_string(),
278                    ));
279                }
280                Ok(Message::Close(frame)) => {
281                    self.closed = true;
282                    return Err(close_frame_to_error(frame));
283                }
284                Ok(_) => continue, // ping/pong/empty — keep reading
285                Err(e) => return Err(WsClientError::from(e)),
286            }
287        }
288        self.closed = true;
289        Ok(None)
290    }
291
292    /// Best-effort graceful close.  Idempotent.
293    pub async fn close(&mut self, code: u16, reason: &str) -> WsResult<()> {
294        if self.closed {
295            return Ok(());
296        }
297        self.closed = true;
298        let frame = CloseFrame {
299            code: CloseCode::from(code),
300            reason: reason.to_owned().into(),
301        };
302        // Wrap in a short timeout so a stuck peer can't hang shutdown.
303        let _ = tokio::time::timeout(
304            Duration::from_secs(5),
305            self.sink.send(Message::Close(Some(frame))),
306        )
307        .await;
308        Ok(())
309    }
310}
311
312fn close_frame_to_error(frame: Option<CloseFrame>) -> WsClientError {
313    if let Some(frame) = frame {
314        let code: u16 = frame.code.into();
315        if code == 4001 {
316            return WsClientError::AuthFailed {
317                reason: format!("server closed 4001: {}", frame.reason),
318            };
319        }
320    }
321    WsClientError::ConnectionClosed
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_connect_url_http_to_ws() {
        let url = build_connect_url("http://api.example/graphics/api", "w-1").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert_eq!(url.path(), "/graphics/api/workers/w-1/connect");
    }

    #[test]
    fn build_connect_url_https_to_wss() {
        let url = build_connect_url("https://api.example/graphics/api/", "w-2").unwrap();
        assert_eq!(url.scheme(), "wss");
        assert_eq!(url.path(), "/graphics/api/workers/w-2/connect");
    }

    #[test]
    fn build_connect_url_appends_graphics_api_prefix_when_missing() {
        let url = build_connect_url("http://localhost:9790", "w-3").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert_eq!(url.path(), "/graphics/api/workers/w-3/connect");
    }

    #[test]
    fn build_connect_url_preserves_existing_ws_scheme() {
        let url = build_connect_url("ws://localhost:9790/x", "w").unwrap();
        assert_eq!(url.scheme(), "ws");
        // Host, port and any existing path prefix are kept.
        assert_eq!(url.host_str(), Some("localhost"));
        assert_eq!(url.port(), Some(9790));
        assert_eq!(url.path(), "/x/graphics/api/workers/w/connect");
    }

    #[test]
    fn build_connect_url_preserves_existing_wss_scheme() {
        let url = build_connect_url("wss://api.example", "w").unwrap();
        assert_eq!(url.scheme(), "wss");
        assert_eq!(url.path(), "/graphics/api/workers/w/connect");
    }

    #[test]
    fn build_connect_url_rejects_unknown_scheme() {
        let err = build_connect_url("ftp://nope/", "w").unwrap_err();
        assert!(matches!(err, WsClientError::Transport(_)));
    }

    #[test]
    fn build_connect_url_rejects_invalid_url() {
        let err = build_connect_url("not a url", "w").unwrap_err();
        assert!(matches!(err, WsClientError::Transport(_)));
    }

    #[test]
    fn close_frame_4001_maps_to_auth_failed() {
        let frame = CloseFrame {
            code: CloseCode::Library(4001),
            reason: "bad token".into(),
        };
        let err = close_frame_to_error(Some(frame));
        assert!(matches!(err, WsClientError::AuthFailed { .. }));
        // The server-supplied reason survives into the error message.
        assert!(err.to_string().contains("bad token"));
    }

    #[test]
    fn close_frame_other_codes_map_to_connection_closed() {
        let frame = CloseFrame {
            code: CloseCode::Normal,
            reason: "bye".into(),
        };
        let err = close_frame_to_error(Some(frame));
        assert!(matches!(err, WsClientError::ConnectionClosed));
    }

    #[test]
    fn close_frame_missing_maps_to_connection_closed() {
        let err = close_frame_to_error(None);
        assert!(matches!(err, WsClientError::ConnectionClosed));
    }

    #[test]
    fn transport_error_round_trips_through_from_impl() {
        let inner = TError::AlreadyClosed;
        let mapped: WsClientError = inner.into();
        assert!(matches!(mapped, WsClientError::ConnectionClosed));
    }

    #[test]
    fn connection_closed_error_maps_to_connection_closed() {
        let mapped = WsClientError::from(TError::ConnectionClosed);
        assert!(matches!(mapped, WsClientError::ConnectionClosed));
    }

    #[test]
    fn io_error_maps_to_transport() {
        let io = std::io::Error::new(std::io::ErrorKind::Other, "boom");
        let mapped = WsClientError::from(TError::Io(io));
        assert!(matches!(mapped, WsClientError::Transport(ref msg) if msg.contains("boom")));
    }
}