Skip to main content

studio_worker/ws/
client.rs

1//! `tokio-tungstenite`-backed client for the studio WS worker channel.
2//!
3//! Responsibilities:
4//!  - coerce `http(s)://` API URLs to `ws(s)://` and append `/connect`
5//!  - attach `Authorization: Bearer <token>` and the
6//!    `studio-worker-v1` sub-protocol header to the upgrade
7//!  - map 401 upgrade responses + 4001 close codes to a typed
8//!    `WsClientError::AuthFailed` so the runtime can surface a
9//!    friendly hint
10//!  - serialise `WorkerInbound` to JSON text frames and parse
11//!    `WorkerOutbound` from incoming frames
12//!  - clean shutdown via `WsClient::close()`
13//!  - emit structured `tracing` breadcrumbs (target
14//!    `studio_worker::ws::client`) at the transport boundary so
15//!    connect / recv / send failures are never silent.  The session
16//!    discards recv errors in its generic `Disconnected(_)` arm and
17//!    fires `let _ = sender.send(...)` for accept / reject / fail /
18//!    completeJson, so this layer is the only place those faults can
19//!    surface.  Mirrors the `studio_worker::http` breadcrumb contract.
20use std::convert::TryFrom;
21use std::time::{Duration, Instant};
22
23use std::sync::Arc;
24
25use futures_util::stream::{SplitSink, SplitStream};
26use futures_util::{SinkExt, StreamExt};
27use tokio::net::TcpStream;
28use tokio::sync::Mutex;
29use tokio_tungstenite::tungstenite::client::IntoClientRequest;
30use tokio_tungstenite::tungstenite::http::HeaderValue;
31use tokio_tungstenite::tungstenite::http::Response;
32use tokio_tungstenite::tungstenite::http::StatusCode;
33use tokio_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFrame};
34use tokio_tungstenite::tungstenite::{Error as TError, Message};
35use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
36use tracing::{debug, warn};
37use url::Url;
38
39use crate::ws::types::{WorkerInbound, WorkerOutbound};
40
/// WebSocket sub-protocol name sent in the `Sec-WebSocket-Protocol`
/// header of the upgrade request.
pub const SUBPROTOCOL: &str = "studio-worker-v1";

/// Tracing target used for every event emitted by the WS client.
/// Stable so operators can filter with
/// `RUST_LOG=studio_worker::ws::client=debug` without enabling
/// wire-level tungstenite logging.
const TRACE_TARGET: &str = "studio_worker::ws::client";
/// Path prefix appended to the base URL's path unless that path already
/// ends with it.  Mirrors the same prefix the HTTP `ApiClient` mounts
/// under.  Stays single-sourced with the API's Hono `basePath('/api')`
/// + outer `/graphics` mount.
const API_PREFIX: &str = "/graphics/api";

/// Upper bound on a single connect attempt (TCP + TLS + WS upgrade). Without it a peer that accepts
/// the socket but stalls the upgrade hangs the reconnect loop forever (no logs, no progress) — the
/// connect-side twin of the read-idle-timeout on an established session.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(15);

/// Result wrapper for WS-client operations; the error side is always
/// a [`WsClientError`].
pub type WsResult<T> = Result<T, WsClientError>;
60
/// Errors surfaced by the client.  All variants carry just enough
/// context to log a useful warning + to drive the reconnect policy.
#[derive(Debug, thiserror::Error)]
pub enum WsClientError {
    /// Upgrade returned 401 or the server closed with 4001.
    /// `reason` is a human-readable description of which of the two
    /// happened (for 4001 it includes the server's close reason).
    #[error("auth failed: {reason}")]
    AuthFailed { reason: String },

    /// Server closed for a reason other than auth failure.  The runtime
    /// treats this as a transient drop and tries to reconnect.
    /// Also produced when the transport reports the connection as
    /// closed / already closed.
    #[error("connection closed by server")]
    ConnectionClosed,

    /// Anything else (DNS, TLS, timeout).  This also covers an
    /// unparseable or unsupported-scheme base URL, an invalid auth
    /// header value, and any non-401 HTTP status on the upgrade (in
    /// which case the string carries the status and a clipped body).
    #[error("ws transport error: {0}")]
    Transport(String),

    /// Frame couldn't be parsed as JSON `WorkerOutbound`.  Also used
    /// for an unexpected binary frame and for an inbound frame that
    /// fails to serialise before sending.
    #[error("protocol error: {0}")]
    Protocol(String),
}
82
83impl From<TError> for WsClientError {
84    fn from(value: TError) -> Self {
85        match value {
86            TError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => {
87                WsClientError::AuthFailed {
88                    reason: "401 on websocket upgrade".to_string(),
89                }
90            }
91            // Any other upgrade status (500/502/503/429 …) is a
92            // transient server-side fault the reconnect loop retries.
93            // Carry the status + body so the studio's error
94            // `reference` id reaches the operator's log.
95            TError::Http(response) => {
96                WsClientError::Transport(http_upgrade_error_message(&response))
97            }
98            TError::ConnectionClosed | TError::AlreadyClosed => WsClientError::ConnectionClosed,
99            other => WsClientError::Transport(other.to_string()),
100        }
101    }
102}
103
/// Upper bound on the number of characters of a non-401 HTTP upgrade
/// body we fold into the transport-error breadcrumb.  Enough to carry
/// the studio's JSON error `reference` id without letting a stray HTML
/// error page flood the log line.  Counted in `char`s, not bytes; a
/// body longer than this is cut and suffixed with an ellipsis.
const HTTP_ERROR_BODY_MAX_CHARS: usize = 300;
109
110/// Render a non-401 HTTP upgrade failure into a transport-error string
111/// that surfaces both the status and (when present) the response body.
112/// tungstenite's own `Error::Http` Display keeps only the status, but
113/// the studio answers a failed `/connect` upgrade with a JSON body
114/// carrying an error `reference` id — the same value Sentry shows for
115/// the matching studio-side event.  Folding it into the breadcrumb
116/// lets an operator correlate the worker's reconnect-loop warning with
117/// the studio's logged failure.  The body is decoded lossily, trimmed,
118/// and clipped to [`HTTP_ERROR_BODY_MAX_CHARS`].
119fn http_upgrade_error_message(response: &Response<Option<Vec<u8>>>) -> String {
120    let status = response.status();
121    let body = response.body().as_deref().and_then(|bytes| {
122        let decoded = String::from_utf8_lossy(bytes);
123        let trimmed = decoded.trim();
124        if trimmed.is_empty() {
125            return None;
126        }
127        Some(clip_error_body(trimmed))
128    });
129    match body {
130        Some(b) => format!("HTTP {status} on websocket upgrade: {b}"),
131        None => format!("HTTP {status} on websocket upgrade"),
132    }
133}
134
135/// Clip a decoded error body to [`HTTP_ERROR_BODY_MAX_CHARS`],
136/// appending an ellipsis when truncated.  Char-based so a multibyte
137/// body can't be split mid-codepoint.
138fn clip_error_body(body: &str) -> String {
139    if body.chars().count() > HTTP_ERROR_BODY_MAX_CHARS {
140        let mut clipped: String = body.chars().take(HTTP_ERROR_BODY_MAX_CHARS).collect();
141        clipped.push('\u{2026}');
142        clipped
143    } else {
144        body.to_string()
145    }
146}
147
148/// Coerce an `http://...api` base URL to the WS URL the server expects.
149fn build_connect_url(base_url: &str, worker_id: &str) -> WsResult<Url> {
150    let mut url = Url::parse(base_url)
151        .map_err(|e| WsClientError::Transport(format!("invalid base url: {e}")))?;
152    let new_scheme = match url.scheme() {
153        "http" => Some("ws"),
154        "https" => Some("wss"),
155        "ws" | "wss" => None, // already in WS form
156        other => {
157            return Err(WsClientError::Transport(format!(
158                "unsupported scheme: {other}"
159            )))
160        }
161    };
162    if let Some(scheme) = new_scheme {
163        url.set_scheme(scheme)
164            .map_err(|_| WsClientError::Transport("set_scheme failed".to_string()))?;
165    }
166    let trimmed_path = url.path().trim_end_matches('/');
167    // Append the studio's `/graphics/api` prefix unless the caller has
168    // already baked it into `base_url` (matches what `ApiClient::url`
169    // does on the HTTP side).
170    let prefixed = if trimmed_path.ends_with(API_PREFIX) {
171        trimmed_path.to_string()
172    } else {
173        format!("{trimmed_path}{API_PREFIX}")
174    };
175    let new_path = format!("{prefixed}/workers/{worker_id}/connect");
176    url.set_path(&new_path);
177    Ok(url)
178}
179
180/// Establish the WebSocket session.  Sends the upgrade with the bearer
181/// token + sub-protocol header and returns a ready-to-use client.
182///
183/// Emits a `debug` breadcrumb on success and a `warn` on failure so a
184/// dead studio, bad DNS, or TLS fault is visible without the caller
185/// having to log it.
186pub async fn connect(base_url: &str, worker_id: &str, auth_token: &str) -> WsResult<WsClient> {
187    let started = Instant::now();
188    let result = connect_inner(base_url, worker_id, auth_token, CONNECT_TIMEOUT).await;
189    let elapsed_ms = started.elapsed().as_millis() as u64;
190    match &result {
191        Ok(_) => debug!(
192            target: TRACE_TARGET,
193            op = "connect",
194            worker_id,
195            elapsed_ms,
196            "websocket established"
197        ),
198        Err(e) => warn!(
199            target: TRACE_TARGET,
200            op = "connect",
201            worker_id,
202            elapsed_ms,
203            error = %e,
204            "websocket connect failed"
205        ),
206    }
207    result
208}
209
210async fn connect_inner(
211    base_url: &str,
212    worker_id: &str,
213    auth_token: &str,
214    connect_timeout: Duration,
215) -> WsResult<WsClient> {
216    let url = build_connect_url(base_url, worker_id)?;
217    debug!(
218        target: TRACE_TARGET,
219        op = "connect",
220        worker_id,
221        url = %url,
222        "opening websocket"
223    );
224    let mut request = url
225        .as_str()
226        .into_client_request()
227        .map_err(WsClientError::from)?;
228    let headers = request.headers_mut();
229    headers.insert(
230        "Authorization",
231        HeaderValue::try_from(format!("Bearer {auth_token}"))
232            .map_err(|e| WsClientError::Transport(format!("invalid auth header: {e}")))?,
233    );
234    headers.insert(
235        "Sec-WebSocket-Protocol",
236        HeaderValue::from_static(SUBPROTOCOL),
237    );
238
239    let (stream, _response) = match tokio::time::timeout(
240        connect_timeout,
241        tokio_tungstenite::connect_async(request),
242    )
243    .await
244    {
245        Ok(result) => result?,
246        Err(_elapsed) => {
247            return Err(WsClientError::Transport(format!(
248                "connect timed out after {connect_timeout:?}"
249            )))
250        }
251    };
252    let (sink, source) = stream.split();
253    Ok(WsClient {
254        sink,
255        source,
256        closed: false,
257    })
258}
259
/// Write half of the split WebSocket stream (what `stream.split()`
/// returns as its sink side); accepts `Message`s.
type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of the split WebSocket stream.
type WsSource = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
262
263/// Active worker-side WS session.  Cheap to construct, expensive to
264/// drop (closes the socket gracefully).
265#[allow(missing_debug_implementations)]
266pub struct WsClient {
267    sink: WsSink,
268    source: WsSource,
269    closed: bool,
270}
271
272impl WsClient {
273    /// Split the client into a cheap-to-clone `WsSender` and a
274    /// single-owner `WsReceiver`.  Used by the runtime so heartbeat,
275    /// log-shipper, and engine-dispatch tasks can all push frames
276    /// concurrently while a dedicated task drains the receive side.
277    pub fn split(self) -> (WsSender, WsReceiver) {
278        let sink = Arc::new(Mutex::new(self.sink));
279        (
280            WsSender { sink },
281            WsReceiver {
282                source: self.source,
283                closed: false,
284            },
285        )
286    }
287}
288
/// Cheap-to-clone send half.  All senders share one `Mutex` over the
/// underlying sink so writes from heartbeat / log-shipper / engine
/// dispatch tasks are serialised correctly.
// NOTE(review): `Debug` is suppressed rather than derived — presumably
// because the wrapped sink has no useful `Debug` output; confirm.
#[derive(Clone)]
#[allow(missing_debug_implementations)]
pub struct WsSender {
    /// Shared, lock-protected write half; cloning the sender clones
    /// only the `Arc`.
    sink: Arc<Mutex<WsSink>>,
}
297
298impl WsSender {
299    pub async fn send(&self, frame: &WorkerInbound) -> WsResult<()> {
300        let text = serialize_frame(frame)?;
301        let mut guard = self.sink.lock().await;
302        guard
303            .send(Message::Text(text.into()))
304            .await
305            .map_err(|e| map_send_failure(frame, e))
306    }
307
308    pub async fn close(&self, code: u16, reason: &str) -> WsResult<()> {
309        debug!(target: TRACE_TARGET, op = "close", code, reason, "closing websocket");
310        let frame = CloseFrame {
311            code: CloseCode::from(code),
312            reason: reason.to_owned().into(),
313        };
314        let mut guard = self.sink.lock().await;
315        if tokio::time::timeout(
316            Duration::from_secs(5),
317            guard.send(Message::Close(Some(frame))),
318        )
319        .await
320        .is_err()
321        {
322            warn!(target: TRACE_TARGET, op = "close", code, "timed out sending close frame");
323        }
324        Ok(())
325    }
326}
327
/// Single-owner receive half.  Owned by the session's reader task.
// NOTE(review): `Debug` is suppressed rather than implemented —
// presumably intentional; confirm.
#[allow(missing_debug_implementations)]
pub struct WsReceiver {
    /// Read half of the split stream.
    source: WsSource,
    /// Latched by `recv` once the stream ends or the server sends a
    /// close frame; further `recv` calls then return `Ok(None)`.
    closed: bool,
}
334
335impl WsReceiver {
336    /// Read the next outbound frame.  Same semantics as
337    /// `WsClient::recv` — silent close → `Ok(None)`, close frame with
338    /// 4001 → `AuthFailed`, other closes → `ConnectionClosed`.
339    pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
340        recv_next(&mut self.source, &mut self.closed).await
341    }
342}
343
344impl std::fmt::Debug for WsClient {
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        f.debug_struct("WsClient")
347            .field("closed", &self.closed)
348            .finish()
349    }
350}
351
352impl WsClient {
353    /// Send a typed inbound frame as a JSON text frame.
354    pub async fn send(&mut self, frame: &WorkerInbound) -> WsResult<()> {
355        let text = serialize_frame(frame)?;
356        self.sink
357            .send(Message::Text(text.into()))
358            .await
359            .map_err(|e| map_send_failure(frame, e))
360    }
361
362    /// Receive the next typed outbound frame.  Returns `Ok(None)` on
363    /// a clean close (no error frame), `Err` on auth or transport
364    /// failures, or `Ok(Some(frame))` for normal traffic.  Pings and
365    /// other control frames are swallowed silently.
366    pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
367        recv_next(&mut self.source, &mut self.closed).await
368    }
369
370    /// Best-effort graceful close.  Idempotent.
371    pub async fn close(&mut self, code: u16, reason: &str) -> WsResult<()> {
372        if self.closed {
373            return Ok(());
374        }
375        self.closed = true;
376        debug!(target: TRACE_TARGET, op = "close", code, reason, "closing websocket");
377        let frame = CloseFrame {
378            code: CloseCode::from(code),
379            reason: reason.to_owned().into(),
380        };
381        // Wrap in a short timeout so a stuck peer can't hang shutdown.
382        if tokio::time::timeout(
383            Duration::from_secs(5),
384            self.sink.send(Message::Close(Some(frame))),
385        )
386        .await
387        .is_err()
388        {
389            warn!(target: TRACE_TARGET, op = "close", code, "timed out sending close frame");
390        }
391        Ok(())
392    }
393}
394
395/// Human-readable label for an inbound frame, used in send-failure
396/// breadcrumbs so operators can tell a dropped `accept` from a dropped
397/// `heartbeat`.
398fn frame_label(frame: &WorkerInbound) -> &'static str {
399    match frame {
400        WorkerInbound::Hello(_) => "hello",
401        WorkerInbound::Heartbeat { .. } => "heartbeat",
402        WorkerInbound::Accept { .. } => "accept",
403        WorkerInbound::Reject { .. } => "reject",
404        WorkerInbound::CompleteJson { .. } => "completeJson",
405        WorkerInbound::Fail { .. } => "fail",
406        WorkerInbound::LogBatch { .. } => "logBatch",
407        WorkerInbound::ReadyForMore => "readyForMore",
408    }
409}
410
411/// Log a failed frame send.  Callers (the session) routinely fire
412/// `let _ = sender.send(...)`, so without this a dropped `accept` /
413/// `fail` / `completeJson` would vanish without trace.
414fn log_send_error(frame: &WorkerInbound, err: &WsClientError) {
415    warn!(
416        target: TRACE_TARGET,
417        op = "send",
418        frame = frame_label(frame),
419        error = %err,
420        "failed to send frame"
421    );
422}
423
424/// Serialise an inbound frame to its JSON wire form, mapping (and
425/// logging) a serialisation failure to a `Protocol` error.  Shared by
426/// `WsSender::send` and `WsClient::send` so the split and monolithic
427/// send paths can't drift in how they encode a frame or report a
428/// failure.
429fn serialize_frame(frame: &WorkerInbound) -> WsResult<String> {
430    serde_json::to_string(frame).map_err(|e| {
431        let err = WsClientError::Protocol(e.to_string());
432        log_send_error(frame, &err);
433        err
434    })
435}
436
437/// Map a sink-level send failure to a logged `WsClientError`.  Shared by
438/// both send paths so a dropped frame always leaves the same breadcrumb.
439fn map_send_failure(frame: &WorkerInbound, e: TError) -> WsClientError {
440    let err = WsClientError::from(e);
441    log_send_error(frame, &err);
442    err
443}
444
445/// The shared `recv` loop body for both the split `WsReceiver` and the
446/// monolithic `WsClient`.  Pulls the next application frame off
447/// `source`, routing every raw message through [`classify_incoming`] so
448/// error / close handling (and its logging) lives in exactly one place.
449/// Latches `*closed` once the stream ends or the server sends a close
450/// frame, matching `recv`'s idempotent-after-close contract.
451async fn recv_next(source: &mut WsSource, closed: &mut bool) -> WsResult<Option<WorkerOutbound>> {
452    if *closed {
453        return Ok(None);
454    }
455    while let Some(item) = source.next().await {
456        match classify_incoming(item) {
457            RecvStep::Yield(frame) => return Ok(Some(frame)),
458            RecvStep::Skip => continue,
459            RecvStep::Fail(e) => return Err(e),
460            RecvStep::Closed(e) => {
461                *closed = true;
462                return Err(e);
463            }
464        }
465    }
466    *closed = true;
467    debug!(target: TRACE_TARGET, op = "recv", "stream ended (no close frame)");
468    Ok(None)
469}
470
/// Interpretation of a single raw WS message during `recv`.  Splitting
/// this out routes every error / close through one logging site
/// ([`classify_incoming`]); the loop scaffolding around it is shared via
/// [`recv_next`], so the split and monolithic receive paths are one
/// implementation.
enum RecvStep {
    /// Decoded application frame to hand back to the caller.
    Yield(WorkerOutbound),
    /// Control / empty frame (ping / pong) — keep reading.
    Skip,
    /// Error to surface without latching the receiver closed.
    /// Produced for unparseable text frames, unexpected binary
    /// frames, and errors reported by the stream itself.
    Fail(WsClientError),
    /// Server sent a close frame — latch closed, then surface the error.
    /// The error is `AuthFailed` for close code 4001 and
    /// `ConnectionClosed` otherwise.
    Closed(WsClientError),
}
486
487/// Classify one incoming message, emitting a tracing breadcrumb for
488/// every failure / close so transport faults are never silent.
489fn classify_incoming(item: Result<Message, TError>) -> RecvStep {
490    match item {
491        Ok(Message::Text(text)) => match serde_json::from_str::<WorkerOutbound>(&text) {
492            Ok(frame) => RecvStep::Yield(frame),
493            Err(e) => {
494                warn!(
495                    target: TRACE_TARGET,
496                    op = "recv",
497                    error = %e,
498                    "dropping unparseable text frame"
499                );
500                RecvStep::Fail(WsClientError::Protocol(e.to_string()))
501            }
502        },
503        Ok(Message::Binary(_)) => {
504            warn!(
505                target: TRACE_TARGET,
506                op = "recv",
507                "rejecting unexpected binary frame"
508            );
509            RecvStep::Fail(WsClientError::Protocol(
510                "unexpected binary frame".to_string(),
511            ))
512        }
513        Ok(Message::Close(frame)) => {
514            let err = close_frame_to_error(frame);
515            match &err {
516                WsClientError::AuthFailed { reason } => warn!(
517                    target: TRACE_TARGET,
518                    op = "recv",
519                    reason = %reason,
520                    "server closed connection: auth failed"
521                ),
522                _ => debug!(
523                    target: TRACE_TARGET,
524                    op = "recv",
525                    "server closed connection"
526                ),
527            }
528            RecvStep::Closed(err)
529        }
530        // ping / pong / empty — keep reading.
531        Ok(_) => RecvStep::Skip,
532        Err(e) => {
533            let mapped = WsClientError::from(e);
534            match &mapped {
535                // A clean close surfaces here as ConnectionClosed on
536                // some transports; keep it at debug to avoid noise on
537                // expected reconnect churn.
538                WsClientError::ConnectionClosed => debug!(
539                    target: TRACE_TARGET,
540                    op = "recv",
541                    "connection closed by peer"
542                ),
543                other => warn!(
544                    target: TRACE_TARGET,
545                    op = "recv",
546                    error = %other,
547                    "transport error while reading frame"
548                ),
549            }
550            RecvStep::Fail(mapped)
551        }
552    }
553}
554
555fn close_frame_to_error(frame: Option<CloseFrame>) -> WsClientError {
556    if let Some(frame) = frame {
557        let code: u16 = frame.code.into();
558        if code == 4001 {
559            return WsClientError::AuthFailed {
560                reason: format!("server closed 4001: {}", frame.reason),
561            };
562        }
563    }
564    WsClientError::ConnectionClosed
565}
566
567#[cfg(test)]
568mod tests {
569    use super::*;
570
571    #[test]
572    fn build_connect_url_http_to_ws() {
573        let url = build_connect_url("http://api.example/graphics/api", "w-1").unwrap();
574        assert_eq!(url.scheme(), "ws");
575        assert!(url.path().ends_with("/workers/w-1/connect"));
576    }
577
578    #[test]
579    fn build_connect_url_https_to_wss() {
580        let url = build_connect_url("https://api.example/graphics/api/", "w-2").unwrap();
581        assert_eq!(url.scheme(), "wss");
582        assert_eq!(url.path(), "/graphics/api/workers/w-2/connect");
583    }
584
585    #[test]
586    fn build_connect_url_appends_graphics_api_prefix_when_missing() {
587        let url = build_connect_url("http://localhost:9790", "w-3").unwrap();
588        assert_eq!(url.scheme(), "ws");
589        assert_eq!(url.path(), "/graphics/api/workers/w-3/connect");
590    }
591
592    #[test]
593    fn build_connect_url_preserves_existing_ws_scheme() {
594        let url = build_connect_url("ws://localhost:9790/x", "w").unwrap();
595        assert_eq!(url.scheme(), "ws");
596    }
597
598    #[test]
599    fn build_connect_url_rejects_unknown_scheme() {
600        let err = build_connect_url("ftp://nope/", "w").unwrap_err();
601        assert!(matches!(err, WsClientError::Transport(_)));
602    }
603
604    #[test]
605    fn build_connect_url_rejects_invalid_url() {
606        let err = build_connect_url("not a url", "w").unwrap_err();
607        assert!(matches!(err, WsClientError::Transport(_)));
608    }
609
610    #[test]
611    fn close_frame_4001_maps_to_auth_failed() {
612        let frame = CloseFrame {
613            code: CloseCode::Library(4001),
614            reason: "bad token".into(),
615        };
616        let err = close_frame_to_error(Some(frame));
617        assert!(matches!(err, WsClientError::AuthFailed { .. }));
618    }
619
620    #[test]
621    fn close_frame_other_codes_map_to_connection_closed() {
622        let frame = CloseFrame {
623            code: CloseCode::Normal,
624            reason: "bye".into(),
625        };
626        let err = close_frame_to_error(Some(frame));
627        assert!(matches!(err, WsClientError::ConnectionClosed));
628    }
629
630    #[test]
631    fn close_frame_missing_maps_to_connection_closed() {
632        let err = close_frame_to_error(None);
633        assert!(matches!(err, WsClientError::ConnectionClosed));
634    }
635
636    #[test]
637    fn transport_error_round_trips_through_from_impl() {
638        let inner = TError::AlreadyClosed;
639        let mapped: WsClientError = inner.into();
640        assert!(matches!(mapped, WsClientError::ConnectionClosed));
641    }
642
643    // -----------------------------------------------------------------
644    // Structured tracing breadcrumbs.  The transport layer must never
645    // swallow a failure silently: callers (the session) discard recv
646    // errors in their generic `Disconnected(_)` arm and use
647    // `let _ = sender.send(...)` for accept/reject/fail/completeJson,
648    // so the only place those faults can surface is here.  Mirrors the
649    // `studio_worker::http` breadcrumb contract.
650    // -----------------------------------------------------------------
651    use crate::test_support::capture;
652
653    #[test]
654    fn classify_rejects_binary_frame_with_warn() {
655        let logs = capture(|| {
656            let step = classify_incoming(Ok(Message::Binary(vec![1, 2, 3].into())));
657            assert!(matches!(step, RecvStep::Fail(WsClientError::Protocol(_))));
658        });
659        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
660        assert!(
661            logs.contains("studio_worker::ws::client"),
662            "expected target, got: {logs}"
663        );
664        assert!(logs.contains("op=\"recv\""), "expected op field: {logs}");
665        assert!(logs.contains("binary"), "expected reason: {logs}");
666    }
667
668    #[test]
669    fn classify_warns_on_unparseable_text_frame() {
670        let logs = capture(|| {
671            let step = classify_incoming(Ok(Message::Text("not json".into())));
672            assert!(matches!(step, RecvStep::Fail(WsClientError::Protocol(_))));
673        });
674        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
675        assert!(logs.contains("op=\"recv\""), "expected op field: {logs}");
676    }
677
678    #[test]
679    fn classify_warns_on_4001_close_frame() {
680        let logs = capture(|| {
681            let frame = CloseFrame {
682                code: CloseCode::Library(4001),
683                reason: "invalid auth token".into(),
684            };
685            let step = classify_incoming(Ok(Message::Close(Some(frame))));
686            assert!(matches!(
687                step,
688                RecvStep::Closed(WsClientError::AuthFailed { .. })
689            ));
690        });
691        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
692        assert!(logs.contains("auth failed"), "expected reason: {logs}");
693    }
694
695    #[test]
696    fn classify_debug_logs_on_normal_close_frame() {
697        let logs = capture(|| {
698            let frame = CloseFrame {
699                code: CloseCode::Normal,
700                reason: "bye".into(),
701            };
702            let step = classify_incoming(Ok(Message::Close(Some(frame))));
703            assert!(matches!(
704                step,
705                RecvStep::Closed(WsClientError::ConnectionClosed)
706            ));
707        });
708        assert!(logs.contains("DEBUG"), "expected DEBUG, got: {logs}");
709        assert!(!logs.contains("WARN"), "normal close must not warn: {logs}");
710        assert!(logs.contains("server closed"), "expected message: {logs}");
711    }
712
713    #[test]
714    fn classify_yields_valid_frame_without_warning() {
715        let logs = capture(|| {
716            let json = serde_json::json!({ "type": "heartbeatAck" }).to_string();
717            let step = classify_incoming(Ok(Message::Text(json.into())));
718            assert!(matches!(
719                step,
720                RecvStep::Yield(WorkerOutbound::HeartbeatAck)
721            ));
722        });
723        assert!(
724            !logs.contains("WARN"),
725            "a valid frame should not warn: {logs}"
726        );
727    }
728
729    #[test]
730    fn classify_skips_control_frames() {
731        assert!(matches!(
732            classify_incoming(Ok(Message::Ping(Vec::new().into()))),
733            RecvStep::Skip
734        ));
735        assert!(matches!(
736            classify_incoming(Ok(Message::Pong(Vec::new().into()))),
737            RecvStep::Skip
738        ));
739    }
740
741    #[test]
742    fn classify_debug_logs_when_the_transport_read_closes_cleanly() {
743        // A clean close can surface as a stream `Err`
744        // (ConnectionClosed / AlreadyClosed) instead of a Close frame
745        // on some transports.  Both must fail the recv but stay at
746        // DEBUG so the logs aren't spammed on every expected reconnect.
747        for already_closed in [false, true] {
748            let logs = capture(move || {
749                let inner = if already_closed {
750                    TError::AlreadyClosed
751                } else {
752                    TError::ConnectionClosed
753                };
754                let step = classify_incoming(Err(inner));
755                assert!(matches!(
756                    step,
757                    RecvStep::Fail(WsClientError::ConnectionClosed)
758                ));
759            });
760            assert!(
761                logs.contains("DEBUG"),
762                "already_closed={already_closed}: expected DEBUG, got: {logs}"
763            );
764            assert!(
765                !logs.contains("WARN"),
766                "already_closed={already_closed}: a clean close must not warn: {logs}"
767            );
768            assert!(
769                logs.contains("connection closed by peer"),
770                "already_closed={already_closed}: expected message: {logs}"
771            );
772        }
773    }
774
775    #[test]
776    fn classify_warns_on_a_transport_read_error() {
777        // A genuine transport fault (not a clean close) must surface
778        // the recv failure at WARN: the session discards recv errors in
779        // its generic `Disconnected(_)` arm, so this breadcrumb is the
780        // only place an operator sees why the session dropped.
781        let logs = capture(|| {
782            let inner = TError::Io(std::io::Error::new(
783                std::io::ErrorKind::ConnectionReset,
784                "peer reset the connection",
785            ));
786            let step = classify_incoming(Err(inner));
787            assert!(matches!(step, RecvStep::Fail(WsClientError::Transport(_))));
788        });
789        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
790        assert!(logs.contains("op=\"recv\""), "expected op field: {logs}");
791        assert!(logs.contains("transport error"), "expected message: {logs}");
792    }
793
794    #[test]
795    fn frame_label_names_every_inbound_variant() {
796        use crate::types::WorkerCapabilities;
797        let caps = WorkerCapabilities {
798            machine_name: String::new(),
799            username: String::new(),
800            agent_version: String::new(),
801            engine: String::new(),
802            vram_total_gb: 0.0,
803            vram_threshold_gb: 0.0,
804            auto_enabled: false,
805            auto_start: false,
806            supported_models: vec![],
807            task_kinds: vec![],
808            supported_models_per_kind: Default::default(),
809        };
810        assert_eq!(
811            frame_label(&WorkerInbound::Hello(crate::ws::types::HelloFrame {
812                auth_token: String::new(),
813                capabilities: caps.clone(),
814            })),
815            "hello"
816        );
817        assert_eq!(
818            frame_label(&WorkerInbound::Heartbeat {
819                capabilities: caps,
820                current_job_id: None,
821            }),
822            "heartbeat"
823        );
824        assert_eq!(
825            frame_label(&WorkerInbound::Accept { job_id: "j".into() }),
826            "accept"
827        );
828        assert_eq!(
829            frame_label(&WorkerInbound::Reject {
830                job_id: "j".into(),
831                reason: "r".into(),
832                code: None,
833            }),
834            "reject"
835        );
836        assert_eq!(
837            frame_label(&WorkerInbound::CompleteJson {
838                job_id: "j".into(),
839                result: serde_json::Value::Null,
840                prompt: None,
841            }),
842            "completeJson"
843        );
844        assert_eq!(
845            frame_label(&WorkerInbound::Fail {
846                job_id: "j".into(),
847                error: "e".into(),
848                retryable: true,
849            }),
850            "fail"
851        );
852        assert_eq!(
853            frame_label(&WorkerInbound::LogBatch { entries: vec![] }),
854            "logBatch"
855        );
856        assert_eq!(frame_label(&WorkerInbound::ReadyForMore), "readyForMore");
857    }
858
859    #[test]
860    fn send_error_logs_warn_with_frame_label() {
861        let logs = capture(|| {
862            log_send_error(
863                &WorkerInbound::Accept {
864                    job_id: "j-1".into(),
865                },
866                &WsClientError::ConnectionClosed,
867            );
868        });
869        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
870        assert!(logs.contains("op=\"send\""), "expected op field: {logs}");
871        assert!(
872            logs.contains("frame=\"accept\""),
873            "expected frame label: {logs}"
874        );
875    }
876
877    #[test]
878    fn serialize_frame_encodes_camel_case_wire_json() {
879        // Both `WsSender::send` and `WsClient::send` route through
880        // `serialize_frame`, so the on-the-wire encoding can't drift
881        // between the split and monolithic send paths. Pin the wire
882        // shape for a representative frame.
883        let json = serialize_frame(&WorkerInbound::Accept {
884            job_id: "j-9".into(),
885        })
886        .expect("a well-formed frame must serialise");
887        assert_eq!(json, r#"{"type":"accept","jobId":"j-9"}"#);
888    }
889
890    #[tokio::test]
891    async fn connect_times_out_against_a_stalling_upgrade() {
892        // A listener that accepts the TCP connection but never answers the WS upgrade. Without the
893        // connect timeout this blocks forever; with it, a transport error must surface fast.
894        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
895        let addr = listener.local_addr().unwrap();
896        tokio::spawn(async move {
897            let _accepted = listener.accept().await; // hold the socket, never upgrade
898            tokio::time::sleep(Duration::from_secs(30)).await;
899        });
900        let url = format!("http://{addr}/graphics/api");
901        let started = Instant::now();
902        let result = connect_inner(&url, "w", "tok", Duration::from_millis(150)).await;
903        assert!(
904            matches!(result, Err(WsClientError::Transport(_))),
905            "expected a transport timeout, got {result:?}"
906        );
907        assert!(
908            started.elapsed() < Duration::from_secs(2),
909            "connect must time out promptly, took {:?}",
910            started.elapsed()
911        );
912    }
913
914    #[test]
915    fn connect_failure_logs_warn_breadcrumb() {
916        // Port 1 has nothing listening, so the upgrade fails fast with
917        // a transport error.  No server required — deterministic.
918        let logs = capture(|| {
919            let rt = tokio::runtime::Builder::new_current_thread()
920                .enable_all()
921                .build()
922                .unwrap();
923            let result = rt.block_on(connect("http://127.0.0.1:1/graphics/api", "w-err", "tok"));
924            assert!(result.is_err(), "connect to a dead port should fail");
925        });
926        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
927        assert!(logs.contains("op=\"connect\""), "expected op field: {logs}");
928        assert!(
929            logs.contains("websocket connect failed"),
930            "expected message: {logs}"
931        );
932        assert!(
933            logs.contains("worker_id=\"w-err\""),
934            "expected worker_id field: {logs}"
935        );
936    }
937
    // -----------------------------------------------------------------
    // From<TError> — HTTP upgrade-error mapping.  A 401 stays a typed
    // AuthFailed so the runtime surfaces a friendly token hint; every
    // other status is a transient server-side fault the reconnect loop
    // retries.  For the latter we fold the studio's response body —
    // which carries the JSON error `reference` id Sentry also shows —
    // into the breadcrumb so the worker's reconnect warning can be
    // correlated with the studio-side failure.  tungstenite's own
    // `Error::Http` Display keeps only the status and drops the body.
    // -----------------------------------------------------------------
948
949    fn http_error(status: u16, body: Option<&[u8]>) -> TError {
950        let response = tokio_tungstenite::tungstenite::http::Response::builder()
951            .status(status)
952            .body(body.map(<[u8]>::to_vec))
953            .expect("a valid response");
954        TError::Http(response)
955    }
956
957    #[test]
958    fn http_401_upgrade_maps_to_auth_failed_ignoring_body() {
959        let err = WsClientError::from(http_error(401, Some(b"any body")));
960        assert!(
961            matches!(err, WsClientError::AuthFailed { .. }),
962            "401 must stay AuthFailed, got {err:?}"
963        );
964    }
965
966    #[test]
967    fn http_500_upgrade_surfaces_status_and_reference_body() {
968        let err = WsClientError::from(http_error(
969            500,
970            Some(b"internal error; reference = q1mtuhheh7en3lfqoofvgfgd"),
971        ));
972        let WsClientError::Transport(msg) = err else {
973            panic!("a non-401 HTTP error must map to Transport, got {err:?}");
974        };
975        assert!(msg.contains("500"), "status must be present: {msg}");
976        assert!(
977            msg.contains("q1mtuhheh7en3lfqoofvgfgd"),
978            "the studio's error reference id must survive into the breadcrumb: {msg}"
979        );
980    }
981
982    #[test]
983    fn http_503_upgrade_without_body_keeps_just_the_status() {
984        let err = WsClientError::from(http_error(503, None));
985        let WsClientError::Transport(msg) = err else {
986            panic!("expected Transport, got {err:?}");
987        };
988        assert!(msg.contains("503"), "status must be present: {msg}");
989        assert!(
990            !msg.trim_end().ends_with(':'),
991            "a bodyless error must not leave a dangling colon: {msg}"
992        );
993    }
994
995    #[test]
996    fn http_upgrade_blank_body_is_treated_as_no_body() {
997        let err = WsClientError::from(http_error(500, Some(b"   \n\t ")));
998        let WsClientError::Transport(msg) = err else {
999            panic!("expected Transport, got {err:?}");
1000        };
1001        assert!(
1002            !msg.trim_end().ends_with(':'),
1003            "a whitespace-only body must not leave a dangling colon: {msg}"
1004        );
1005    }
1006
1007    #[test]
1008    fn http_upgrade_error_body_is_clipped() {
1009        let big = "x".repeat(5_000);
1010        let err = WsClientError::from(http_error(502, Some(big.as_bytes())));
1011        let WsClientError::Transport(msg) = err else {
1012            panic!("expected Transport, got {err:?}");
1013        };
1014        assert!(
1015            msg.chars().count() < big.len(),
1016            "a huge error page must be clipped, got {} chars",
1017            msg.chars().count()
1018        );
1019        assert!(
1020            msg.contains('\u{2026}'),
1021            "a clipped body must carry an ellipsis: {msg}"
1022        );
1023    }
1024
1025    #[test]
1026    fn http_upgrade_error_body_clips_on_char_boundary_not_mid_codepoint() {
1027        // The studio's failed-`/connect` upgrade body (the one carrying
1028        // the Sentry `reference` id — see STUDIO-WORKER-1) can contain
1029        // multibyte characters: a Cloudflare HTML error page, a
1030        // non-ASCII operator hostname echoed back, an em-dash, etc.
1031        // `clip_error_body` documents that it clips on *character*
1032        // boundaries so such a body is never split mid-codepoint.  Pad
1033        // with ASCII up to one char short of the limit, then place a
1034        // 3-byte char straddling the byte that a naive `&body[..N]`
1035        // byte-slice would cut on — char-based clipping keeps it whole,
1036        // a byte-slice regression would panic before producing any
1037        // message at all.
1038        let body = format!(
1039            "{}{}",
1040            "a".repeat(HTTP_ERROR_BODY_MAX_CHARS - 1),
1041            "\u{4e16}".repeat(101)
1042        );
1043        let err = WsClientError::from(http_error(502, Some(body.as_bytes())));
1044        let WsClientError::Transport(msg) = err else {
1045            panic!("expected Transport, got {err:?}");
1046        };
1047        assert!(msg.contains("502"), "status must survive: {msg}");
1048        assert!(
1049            msg.contains('\u{2026}'),
1050            "an over-limit body must be clipped: {msg}"
1051        );
1052        assert!(
1053            msg.contains('\u{4e16}'),
1054            "the char straddling the clip point must survive whole: {msg}"
1055        );
1056        assert!(
1057            !msg.contains('\u{fffd}'),
1058            "no codepoint may be split (no replacement char): {msg}"
1059        );
1060    }
1061
1062    #[test]
1063    fn clip_error_body_keeps_an_exactly_at_limit_body_verbatim() {
1064        // The clip is gated on `chars().count() > MAX`, so a body of
1065        // exactly `MAX` chars must pass through untouched (no ellipsis),
1066        // and one char over must clip.  Pins the off-by-one boundary so
1067        // a `>=`-vs-`>` regression can't silently start truncating a
1068        // body that fit.
1069        let at_limit = "x".repeat(HTTP_ERROR_BODY_MAX_CHARS);
1070        let clipped = clip_error_body(&at_limit);
1071        assert_eq!(
1072            clipped, at_limit,
1073            "a body exactly at the limit must be returned verbatim"
1074        );
1075        assert!(
1076            !clipped.contains('\u{2026}'),
1077            "an at-limit body must not gain an ellipsis: {clipped}"
1078        );
1079
1080        let over_limit = "x".repeat(HTTP_ERROR_BODY_MAX_CHARS + 1);
1081        let clipped = clip_error_body(&over_limit);
1082        assert_eq!(
1083            clipped.chars().count(),
1084            HTTP_ERROR_BODY_MAX_CHARS + 1,
1085            "an over-limit body keeps MAX chars plus the ellipsis"
1086        );
1087        assert!(
1088            clipped.ends_with('\u{2026}'),
1089            "an over-limit body must end with an ellipsis: {clipped}"
1090        );
1091    }
1092}