1use std::convert::TryFrom;
21use std::time::{Duration, Instant};
22
23use std::sync::Arc;
24
25use futures_util::stream::{SplitSink, SplitStream};
26use futures_util::{SinkExt, StreamExt};
27use tokio::net::TcpStream;
28use tokio::sync::Mutex;
29use tokio_tungstenite::tungstenite::client::IntoClientRequest;
30use tokio_tungstenite::tungstenite::http::HeaderValue;
31use tokio_tungstenite::tungstenite::http::Response;
32use tokio_tungstenite::tungstenite::http::StatusCode;
33use tokio_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFrame};
34use tokio_tungstenite::tungstenite::{Error as TError, Message};
35use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
36use tracing::{debug, warn};
37use url::Url;
38
39use crate::ws::types::{WorkerInbound, WorkerOutbound};
40
/// WebSocket subprotocol the worker requests via `Sec-WebSocket-Protocol`.
pub const SUBPROTOCOL: &str = "studio-worker-v1";

/// `tracing` target shared by every event emitted from this module.
const TRACE_TARGET: &str = "studio_worker::ws::client";

/// Studio graphics API path prefix; appended to the base URL's path when it is
/// not already there.
const API_PREFIX: &str = "/graphics/api";

/// Upper bound on the whole `connect_async` call (connection plus websocket
/// upgrade) made by `connect`.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(15_000);
58pub type WsResult<T> = Result<T, WsClientError>;
60
61#[derive(Debug, thiserror::Error)]
64pub enum WsClientError {
65 #[error("auth failed: {reason}")]
67 AuthFailed { reason: String },
68
69 #[error("connection closed by server")]
72 ConnectionClosed,
73
74 #[error("ws transport error: {0}")]
76 Transport(String),
77
78 #[error("protocol error: {0}")]
80 Protocol(String),
81}
82
83impl From<TError> for WsClientError {
84 fn from(value: TError) -> Self {
85 match value {
86 TError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => {
87 WsClientError::AuthFailed {
88 reason: "401 on websocket upgrade".to_string(),
89 }
90 }
91 TError::Http(response) => {
96 WsClientError::Transport(http_upgrade_error_message(&response))
97 }
98 TError::ConnectionClosed | TError::AlreadyClosed => WsClientError::ConnectionClosed,
99 other => WsClientError::Transport(other.to_string()),
100 }
101 }
102}
103
/// Maximum number of characters of an HTTP error body kept in an upgrade
/// failure message; longer bodies are clipped and suffixed with an ellipsis.
const HTTP_ERROR_BODY_MAX_CHARS: usize = 300;
109
110fn http_upgrade_error_message(response: &Response<Option<Vec<u8>>>) -> String {
120 let status = response.status();
121 let body = response.body().as_deref().and_then(|bytes| {
122 let decoded = String::from_utf8_lossy(bytes);
123 let trimmed = decoded.trim();
124 if trimmed.is_empty() {
125 return None;
126 }
127 Some(clip_error_body(trimmed))
128 });
129 match body {
130 Some(b) => format!("HTTP {status} on websocket upgrade: {b}"),
131 None => format!("HTTP {status} on websocket upgrade"),
132 }
133}
134
135fn clip_error_body(body: &str) -> String {
139 if body.chars().count() > HTTP_ERROR_BODY_MAX_CHARS {
140 let mut clipped: String = body.chars().take(HTTP_ERROR_BODY_MAX_CHARS).collect();
141 clipped.push('\u{2026}');
142 clipped
143 } else {
144 body.to_string()
145 }
146}
147
148fn build_connect_url(base_url: &str, worker_id: &str) -> WsResult<Url> {
150 let mut url = Url::parse(base_url)
151 .map_err(|e| WsClientError::Transport(format!("invalid base url: {e}")))?;
152 let new_scheme = match url.scheme() {
153 "http" => Some("ws"),
154 "https" => Some("wss"),
155 "ws" | "wss" => None, other => {
157 return Err(WsClientError::Transport(format!(
158 "unsupported scheme: {other}"
159 )))
160 }
161 };
162 if let Some(scheme) = new_scheme {
163 url.set_scheme(scheme)
164 .map_err(|_| WsClientError::Transport("set_scheme failed".to_string()))?;
165 }
166 let trimmed_path = url.path().trim_end_matches('/');
167 let prefixed = if trimmed_path.ends_with(API_PREFIX) {
171 trimmed_path.to_string()
172 } else {
173 format!("{trimmed_path}{API_PREFIX}")
174 };
175 let new_path = format!("{prefixed}/workers/{worker_id}/connect");
176 url.set_path(&new_path);
177 Ok(url)
178}
179
180pub async fn connect(base_url: &str, worker_id: &str, auth_token: &str) -> WsResult<WsClient> {
187 let started = Instant::now();
188 let result = connect_inner(base_url, worker_id, auth_token, CONNECT_TIMEOUT).await;
189 let elapsed_ms = started.elapsed().as_millis() as u64;
190 match &result {
191 Ok(_) => debug!(
192 target: TRACE_TARGET,
193 op = "connect",
194 worker_id,
195 elapsed_ms,
196 "websocket established"
197 ),
198 Err(e) => warn!(
199 target: TRACE_TARGET,
200 op = "connect",
201 worker_id,
202 elapsed_ms,
203 error = %e,
204 "websocket connect failed"
205 ),
206 }
207 result
208}
209
210async fn connect_inner(
211 base_url: &str,
212 worker_id: &str,
213 auth_token: &str,
214 connect_timeout: Duration,
215) -> WsResult<WsClient> {
216 let url = build_connect_url(base_url, worker_id)?;
217 debug!(
218 target: TRACE_TARGET,
219 op = "connect",
220 worker_id,
221 url = %url,
222 "opening websocket"
223 );
224 let mut request = url
225 .as_str()
226 .into_client_request()
227 .map_err(WsClientError::from)?;
228 let headers = request.headers_mut();
229 headers.insert(
230 "Authorization",
231 HeaderValue::try_from(format!("Bearer {auth_token}"))
232 .map_err(|e| WsClientError::Transport(format!("invalid auth header: {e}")))?,
233 );
234 headers.insert(
235 "Sec-WebSocket-Protocol",
236 HeaderValue::from_static(SUBPROTOCOL),
237 );
238
239 let (stream, _response) = match tokio::time::timeout(
240 connect_timeout,
241 tokio_tungstenite::connect_async(request),
242 )
243 .await
244 {
245 Ok(result) => result?,
246 Err(_elapsed) => {
247 return Err(WsClientError::Transport(format!(
248 "connect timed out after {connect_timeout:?}"
249 )))
250 }
251 };
252 let (sink, source) = stream.split();
253 Ok(WsClient {
254 sink,
255 source,
256 closed: false,
257 })
258}
259
260type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
261type WsSource = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
262
263#[allow(missing_debug_implementations)]
266pub struct WsClient {
267 sink: WsSink,
268 source: WsSource,
269 closed: bool,
270}
271
272impl WsClient {
273 pub fn split(self) -> (WsSender, WsReceiver) {
278 let sink = Arc::new(Mutex::new(self.sink));
279 (
280 WsSender { sink },
281 WsReceiver {
282 source: self.source,
283 closed: false,
284 },
285 )
286 }
287}
288
289#[derive(Clone)]
293#[allow(missing_debug_implementations)]
294pub struct WsSender {
295 sink: Arc<Mutex<WsSink>>,
296}
297
298impl WsSender {
299 pub async fn send(&self, frame: &WorkerInbound) -> WsResult<()> {
300 let text = serialize_frame(frame)?;
301 let mut guard = self.sink.lock().await;
302 guard
303 .send(Message::Text(text.into()))
304 .await
305 .map_err(|e| map_send_failure(frame, e))
306 }
307
308 pub async fn close(&self, code: u16, reason: &str) -> WsResult<()> {
309 debug!(target: TRACE_TARGET, op = "close", code, reason, "closing websocket");
310 let frame = CloseFrame {
311 code: CloseCode::from(code),
312 reason: reason.to_owned().into(),
313 };
314 let mut guard = self.sink.lock().await;
315 if tokio::time::timeout(
316 Duration::from_secs(5),
317 guard.send(Message::Close(Some(frame))),
318 )
319 .await
320 .is_err()
321 {
322 warn!(target: TRACE_TARGET, op = "close", code, "timed out sending close frame");
323 }
324 Ok(())
325 }
326}
327
328#[allow(missing_debug_implementations)]
330pub struct WsReceiver {
331 source: WsSource,
332 closed: bool,
333}
334
335impl WsReceiver {
336 pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
340 recv_next(&mut self.source, &mut self.closed).await
341 }
342}
343
344impl std::fmt::Debug for WsClient {
345 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346 f.debug_struct("WsClient")
347 .field("closed", &self.closed)
348 .finish()
349 }
350}
351
352impl WsClient {
353 pub async fn send(&mut self, frame: &WorkerInbound) -> WsResult<()> {
355 let text = serialize_frame(frame)?;
356 self.sink
357 .send(Message::Text(text.into()))
358 .await
359 .map_err(|e| map_send_failure(frame, e))
360 }
361
362 pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
367 recv_next(&mut self.source, &mut self.closed).await
368 }
369
370 pub async fn close(&mut self, code: u16, reason: &str) -> WsResult<()> {
372 if self.closed {
373 return Ok(());
374 }
375 self.closed = true;
376 debug!(target: TRACE_TARGET, op = "close", code, reason, "closing websocket");
377 let frame = CloseFrame {
378 code: CloseCode::from(code),
379 reason: reason.to_owned().into(),
380 };
381 if tokio::time::timeout(
383 Duration::from_secs(5),
384 self.sink.send(Message::Close(Some(frame))),
385 )
386 .await
387 .is_err()
388 {
389 warn!(target: TRACE_TARGET, op = "close", code, "timed out sending close frame");
390 }
391 Ok(())
392 }
393}
394
395fn frame_label(frame: &WorkerInbound) -> &'static str {
399 match frame {
400 WorkerInbound::Hello(_) => "hello",
401 WorkerInbound::Heartbeat { .. } => "heartbeat",
402 WorkerInbound::Accept { .. } => "accept",
403 WorkerInbound::Reject { .. } => "reject",
404 WorkerInbound::CompleteJson { .. } => "completeJson",
405 WorkerInbound::Fail { .. } => "fail",
406 WorkerInbound::LogBatch { .. } => "logBatch",
407 WorkerInbound::ReadyForMore => "readyForMore",
408 }
409}
410
411fn log_send_error(frame: &WorkerInbound, err: &WsClientError) {
415 warn!(
416 target: TRACE_TARGET,
417 op = "send",
418 frame = frame_label(frame),
419 error = %err,
420 "failed to send frame"
421 );
422}
423
424fn serialize_frame(frame: &WorkerInbound) -> WsResult<String> {
430 serde_json::to_string(frame).map_err(|e| {
431 let err = WsClientError::Protocol(e.to_string());
432 log_send_error(frame, &err);
433 err
434 })
435}
436
437fn map_send_failure(frame: &WorkerInbound, e: TError) -> WsClientError {
440 let err = WsClientError::from(e);
441 log_send_error(frame, &err);
442 err
443}
444
445async fn recv_next(source: &mut WsSource, closed: &mut bool) -> WsResult<Option<WorkerOutbound>> {
452 if *closed {
453 return Ok(None);
454 }
455 while let Some(item) = source.next().await {
456 match classify_incoming(item) {
457 RecvStep::Yield(frame) => return Ok(Some(frame)),
458 RecvStep::Skip => continue,
459 RecvStep::Fail(e) => return Err(e),
460 RecvStep::Closed(e) => {
461 *closed = true;
462 return Err(e);
463 }
464 }
465 }
466 *closed = true;
467 debug!(target: TRACE_TARGET, op = "recv", "stream ended (no close frame)");
468 Ok(None)
469}
470
471enum RecvStep {
477 Yield(WorkerOutbound),
479 Skip,
481 Fail(WsClientError),
483 Closed(WsClientError),
485}
486
487fn classify_incoming(item: Result<Message, TError>) -> RecvStep {
490 match item {
491 Ok(Message::Text(text)) => match serde_json::from_str::<WorkerOutbound>(&text) {
492 Ok(frame) => RecvStep::Yield(frame),
493 Err(e) => {
494 warn!(
495 target: TRACE_TARGET,
496 op = "recv",
497 error = %e,
498 "dropping unparseable text frame"
499 );
500 RecvStep::Fail(WsClientError::Protocol(e.to_string()))
501 }
502 },
503 Ok(Message::Binary(_)) => {
504 warn!(
505 target: TRACE_TARGET,
506 op = "recv",
507 "rejecting unexpected binary frame"
508 );
509 RecvStep::Fail(WsClientError::Protocol(
510 "unexpected binary frame".to_string(),
511 ))
512 }
513 Ok(Message::Close(frame)) => {
514 let err = close_frame_to_error(frame);
515 match &err {
516 WsClientError::AuthFailed { reason } => warn!(
517 target: TRACE_TARGET,
518 op = "recv",
519 reason = %reason,
520 "server closed connection: auth failed"
521 ),
522 _ => debug!(
523 target: TRACE_TARGET,
524 op = "recv",
525 "server closed connection"
526 ),
527 }
528 RecvStep::Closed(err)
529 }
530 Ok(_) => RecvStep::Skip,
532 Err(e) => {
533 let mapped = WsClientError::from(e);
534 match &mapped {
535 WsClientError::ConnectionClosed => debug!(
539 target: TRACE_TARGET,
540 op = "recv",
541 "connection closed by peer"
542 ),
543 other => warn!(
544 target: TRACE_TARGET,
545 op = "recv",
546 error = %other,
547 "transport error while reading frame"
548 ),
549 }
550 RecvStep::Fail(mapped)
551 }
552 }
553}
554
555fn close_frame_to_error(frame: Option<CloseFrame>) -> WsClientError {
556 if let Some(frame) = frame {
557 let code: u16 = frame.code.into();
558 if code == 4001 {
559 return WsClientError::AuthFailed {
560 reason: format!("server closed 4001: {}", frame.reason),
561 };
562 }
563 }
564 WsClientError::ConnectionClosed
565}
566
// Unit tests for URL construction, error mapping, close-frame handling,
// incoming-frame classification and the tracing breadcrumbs each path emits.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- build_connect_url ----

    #[test]
    fn build_connect_url_http_to_ws() {
        let url = build_connect_url("http://api.example/graphics/api", "w-1").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert!(url.path().ends_with("/workers/w-1/connect"));
    }

    #[test]
    fn build_connect_url_https_to_wss() {
        // A trailing slash on the base must not produce a double slash.
        let url = build_connect_url("https://api.example/graphics/api/", "w-2").unwrap();
        assert_eq!(url.scheme(), "wss");
        assert_eq!(url.path(), "/graphics/api/workers/w-2/connect");
    }

    #[test]
    fn build_connect_url_appends_graphics_api_prefix_when_missing() {
        let url = build_connect_url("http://localhost:9790", "w-3").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert_eq!(url.path(), "/graphics/api/workers/w-3/connect");
    }

    #[test]
    fn build_connect_url_preserves_existing_ws_scheme() {
        let url = build_connect_url("ws://localhost:9790/x", "w").unwrap();
        assert_eq!(url.scheme(), "ws");
    }

    #[test]
    fn build_connect_url_rejects_unknown_scheme() {
        let err = build_connect_url("ftp://nope/", "w").unwrap_err();
        assert!(matches!(err, WsClientError::Transport(_)));
    }

    #[test]
    fn build_connect_url_rejects_invalid_url() {
        let err = build_connect_url("not a url", "w").unwrap_err();
        assert!(matches!(err, WsClientError::Transport(_)));
    }

    // ---- close frames and TError mapping ----

    #[test]
    fn close_frame_4001_maps_to_auth_failed() {
        let frame = CloseFrame {
            code: CloseCode::Library(4001),
            reason: "bad token".into(),
        };
        let err = close_frame_to_error(Some(frame));
        assert!(matches!(err, WsClientError::AuthFailed { .. }));
    }

    #[test]
    fn close_frame_other_codes_map_to_connection_closed() {
        let frame = CloseFrame {
            code: CloseCode::Normal,
            reason: "bye".into(),
        };
        let err = close_frame_to_error(Some(frame));
        assert!(matches!(err, WsClientError::ConnectionClosed));
    }

    #[test]
    fn close_frame_missing_maps_to_connection_closed() {
        let err = close_frame_to_error(None);
        assert!(matches!(err, WsClientError::ConnectionClosed));
    }

    #[test]
    fn transport_error_round_trips_through_from_impl() {
        let inner = TError::AlreadyClosed;
        let mapped: WsClientError = inner.into();
        assert!(matches!(mapped, WsClientError::ConnectionClosed));
    }

    // `capture` lives in the crate's test-support module (not visible here).
    // Presumably it runs the closure with a tracing subscriber installed and
    // returns the formatted log output as a string; confirm in `test_support`.
    use crate::test_support::capture;

    // ---- classify_incoming: level, target and fields of each breadcrumb ----

    #[test]
    fn classify_rejects_binary_frame_with_warn() {
        let logs = capture(|| {
            let step = classify_incoming(Ok(Message::Binary(vec![1, 2, 3].into())));
            assert!(matches!(step, RecvStep::Fail(WsClientError::Protocol(_))));
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(
            logs.contains("studio_worker::ws::client"),
            "expected target, got: {logs}"
        );
        assert!(logs.contains("op=\"recv\""), "expected op field: {logs}");
        assert!(logs.contains("binary"), "expected reason: {logs}");
    }

    #[test]
    fn classify_warns_on_unparseable_text_frame() {
        let logs = capture(|| {
            let step = classify_incoming(Ok(Message::Text("not json".into())));
            assert!(matches!(step, RecvStep::Fail(WsClientError::Protocol(_))));
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(logs.contains("op=\"recv\""), "expected op field: {logs}");
    }

    #[test]
    fn classify_warns_on_4001_close_frame() {
        let logs = capture(|| {
            let frame = CloseFrame {
                code: CloseCode::Library(4001),
                reason: "invalid auth token".into(),
            };
            let step = classify_incoming(Ok(Message::Close(Some(frame))));
            assert!(matches!(
                step,
                RecvStep::Closed(WsClientError::AuthFailed { .. })
            ));
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(logs.contains("auth failed"), "expected reason: {logs}");
    }

    #[test]
    fn classify_debug_logs_on_normal_close_frame() {
        let logs = capture(|| {
            let frame = CloseFrame {
                code: CloseCode::Normal,
                reason: "bye".into(),
            };
            let step = classify_incoming(Ok(Message::Close(Some(frame))));
            assert!(matches!(
                step,
                RecvStep::Closed(WsClientError::ConnectionClosed)
            ));
        });
        assert!(logs.contains("DEBUG"), "expected DEBUG, got: {logs}");
        assert!(!logs.contains("WARN"), "normal close must not warn: {logs}");
        assert!(logs.contains("server closed"), "expected message: {logs}");
    }

    #[test]
    fn classify_yields_valid_frame_without_warning() {
        let logs = capture(|| {
            let json = serde_json::json!({ "type": "heartbeatAck" }).to_string();
            let step = classify_incoming(Ok(Message::Text(json.into())));
            assert!(matches!(
                step,
                RecvStep::Yield(WorkerOutbound::HeartbeatAck)
            ));
        });
        assert!(
            !logs.contains("WARN"),
            "a valid frame should not warn: {logs}"
        );
    }

    #[test]
    fn classify_skips_control_frames() {
        assert!(matches!(
            classify_incoming(Ok(Message::Ping(Vec::new().into()))),
            RecvStep::Skip
        ));
        assert!(matches!(
            classify_incoming(Ok(Message::Pong(Vec::new().into()))),
            RecvStep::Skip
        ));
    }

    #[test]
    fn classify_debug_logs_when_the_transport_read_closes_cleanly() {
        // Both "clean close" transport errors must log at DEBUG, never WARN.
        for already_closed in [false, true] {
            let logs = capture(move || {
                let inner = if already_closed {
                    TError::AlreadyClosed
                } else {
                    TError::ConnectionClosed
                };
                let step = classify_incoming(Err(inner));
                assert!(matches!(
                    step,
                    RecvStep::Fail(WsClientError::ConnectionClosed)
                ));
            });
            assert!(
                logs.contains("DEBUG"),
                "already_closed={already_closed}: expected DEBUG, got: {logs}"
            );
            assert!(
                !logs.contains("WARN"),
                "already_closed={already_closed}: a clean close must not warn: {logs}"
            );
            assert!(
                logs.contains("connection closed by peer"),
                "already_closed={already_closed}: expected message: {logs}"
            );
        }
    }

    #[test]
    fn classify_warns_on_a_transport_read_error() {
        let logs = capture(|| {
            // An I/O reset is a real transport failure, not a clean close.
            let inner = TError::Io(std::io::Error::new(
                std::io::ErrorKind::ConnectionReset,
                "peer reset the connection",
            ));
            let step = classify_incoming(Err(inner));
            assert!(matches!(step, RecvStep::Fail(WsClientError::Transport(_))));
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(logs.contains("op=\"recv\""), "expected op field: {logs}");
        assert!(logs.contains("transport error"), "expected message: {logs}");
    }

    // ---- send-side helpers ----

    #[test]
    fn frame_label_names_every_inbound_variant() {
        use crate::types::WorkerCapabilities;
        let caps = WorkerCapabilities {
            machine_name: String::new(),
            username: String::new(),
            agent_version: String::new(),
            engine: String::new(),
            vram_total_gb: 0.0,
            vram_threshold_gb: 0.0,
            auto_enabled: false,
            auto_start: false,
            supported_models: vec![],
            task_kinds: vec![],
            supported_models_per_kind: Default::default(),
        };
        assert_eq!(
            frame_label(&WorkerInbound::Hello(crate::ws::types::HelloFrame {
                auth_token: String::new(),
                capabilities: caps.clone(),
            })),
            "hello"
        );
        assert_eq!(
            frame_label(&WorkerInbound::Heartbeat {
                capabilities: caps,
                current_job_id: None,
            }),
            "heartbeat"
        );
        assert_eq!(
            frame_label(&WorkerInbound::Accept { job_id: "j".into() }),
            "accept"
        );
        assert_eq!(
            frame_label(&WorkerInbound::Reject {
                job_id: "j".into(),
                reason: "r".into(),
                code: None,
            }),
            "reject"
        );
        assert_eq!(
            frame_label(&WorkerInbound::CompleteJson {
                job_id: "j".into(),
                result: serde_json::Value::Null,
                prompt: None,
            }),
            "completeJson"
        );
        assert_eq!(
            frame_label(&WorkerInbound::Fail {
                job_id: "j".into(),
                error: "e".into(),
                retryable: true,
            }),
            "fail"
        );
        assert_eq!(
            frame_label(&WorkerInbound::LogBatch { entries: vec![] }),
            "logBatch"
        );
        assert_eq!(frame_label(&WorkerInbound::ReadyForMore), "readyForMore");
    }

    #[test]
    fn send_error_logs_warn_with_frame_label() {
        let logs = capture(|| {
            log_send_error(
                &WorkerInbound::Accept {
                    job_id: "j-1".into(),
                },
                &WsClientError::ConnectionClosed,
            );
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(logs.contains("op=\"send\""), "expected op field: {logs}");
        assert!(
            logs.contains("frame=\"accept\""),
            "expected frame label: {logs}"
        );
    }

    #[test]
    fn serialize_frame_encodes_camel_case_wire_json() {
        // Pins the exact wire shape: `type` tag plus camelCase field names.
        let json = serialize_frame(&WorkerInbound::Accept {
            job_id: "j-9".into(),
        })
        .expect("a well-formed frame must serialise");
        assert_eq!(json, r#"{"type":"accept","jobId":"j-9"}"#);
    }

    // ---- connect ----

    #[tokio::test]
    async fn connect_times_out_against_a_stalling_upgrade() {
        // A listener that accepts the TCP connection but never answers the
        // HTTP upgrade, so only the connect timeout can end the attempt.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _accepted = listener.accept().await;
            tokio::time::sleep(Duration::from_secs(30)).await;
        });
        let url = format!("http://{addr}/graphics/api");
        let started = Instant::now();
        let result = connect_inner(&url, "w", "tok", Duration::from_millis(150)).await;
        assert!(
            matches!(result, Err(WsClientError::Transport(_))),
            "expected a transport timeout, got {result:?}"
        );
        assert!(
            started.elapsed() < Duration::from_secs(2),
            "connect must time out promptly, took {:?}",
            started.elapsed()
        );
    }

    #[test]
    fn connect_failure_logs_warn_breadcrumb() {
        let logs = capture(|| {
            // `capture` takes a synchronous closure, so the async `connect`
            // runs on a dedicated current-thread runtime. Nothing is expected
            // to listen on port 1.
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let result = rt.block_on(connect("http://127.0.0.1:1/graphics/api", "w-err", "tok"));
            assert!(result.is_err(), "connect to a dead port should fail");
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(logs.contains("op=\"connect\""), "expected op field: {logs}");
        assert!(
            logs.contains("websocket connect failed"),
            "expected message: {logs}"
        );
        assert!(
            logs.contains("worker_id=\"w-err\""),
            "expected worker_id field: {logs}"
        );
    }

    // ---- HTTP upgrade failures ----

    /// Builds a `TError::Http` with the given status and optional body, the
    /// shape tungstenite reports for a rejected upgrade.
    fn http_error(status: u16, body: Option<&[u8]>) -> TError {
        let response = tokio_tungstenite::tungstenite::http::Response::builder()
            .status(status)
            .body(body.map(<[u8]>::to_vec))
            .expect("a valid response");
        TError::Http(response)
    }

    #[test]
    fn http_401_upgrade_maps_to_auth_failed_ignoring_body() {
        let err = WsClientError::from(http_error(401, Some(b"any body")));
        assert!(
            matches!(err, WsClientError::AuthFailed { .. }),
            "401 must stay AuthFailed, got {err:?}"
        );
    }

    #[test]
    fn http_500_upgrade_surfaces_status_and_reference_body() {
        let err = WsClientError::from(http_error(
            500,
            Some(b"internal error; reference = q1mtuhheh7en3lfqoofvgfgd"),
        ));
        let WsClientError::Transport(msg) = err else {
            panic!("a non-401 HTTP error must map to Transport, got {err:?}");
        };
        assert!(msg.contains("500"), "status must be present: {msg}");
        assert!(
            msg.contains("q1mtuhheh7en3lfqoofvgfgd"),
            "the studio's error reference id must survive into the breadcrumb: {msg}"
        );
    }

    #[test]
    fn http_503_upgrade_without_body_keeps_just_the_status() {
        let err = WsClientError::from(http_error(503, None));
        let WsClientError::Transport(msg) = err else {
            panic!("expected Transport, got {err:?}");
        };
        assert!(msg.contains("503"), "status must be present: {msg}");
        assert!(
            !msg.trim_end().ends_with(':'),
            "a bodyless error must not leave a dangling colon: {msg}"
        );
    }

    #[test]
    fn http_upgrade_blank_body_is_treated_as_no_body() {
        let err = WsClientError::from(http_error(500, Some(b" \n\t ")));
        let WsClientError::Transport(msg) = err else {
            panic!("expected Transport, got {err:?}");
        };
        assert!(
            !msg.trim_end().ends_with(':'),
            "a whitespace-only body must not leave a dangling colon: {msg}"
        );
    }

    #[test]
    fn http_upgrade_error_body_is_clipped() {
        let big = "x".repeat(5_000);
        let err = WsClientError::from(http_error(502, Some(big.as_bytes())));
        let WsClientError::Transport(msg) = err else {
            panic!("expected Transport, got {err:?}");
        };
        assert!(
            msg.chars().count() < big.len(),
            "a huge error page must be clipped, got {} chars",
            msg.chars().count()
        );
        assert!(
            msg.contains('\u{2026}'),
            "a clipped body must carry an ellipsis: {msg}"
        );
    }

    #[test]
    fn http_upgrade_error_body_clips_on_char_boundary_not_mid_codepoint() {
        // MAX-1 ASCII chars followed by 3-byte CJK chars: the clip point falls
        // right after the first CJK char, which must survive whole.
        let body = format!(
            "{}{}",
            "a".repeat(HTTP_ERROR_BODY_MAX_CHARS - 1),
            "\u{4e16}".repeat(101)
        );
        let err = WsClientError::from(http_error(502, Some(body.as_bytes())));
        let WsClientError::Transport(msg) = err else {
            panic!("expected Transport, got {err:?}");
        };
        assert!(msg.contains("502"), "status must survive: {msg}");
        assert!(
            msg.contains('\u{2026}'),
            "an over-limit body must be clipped: {msg}"
        );
        assert!(
            msg.contains('\u{4e16}'),
            "the char straddling the clip point must survive whole: {msg}"
        );
        assert!(
            !msg.contains('\u{fffd}'),
            "no codepoint may be split (no replacement char): {msg}"
        );
    }

    #[test]
    fn clip_error_body_keeps_an_exactly_at_limit_body_verbatim() {
        // Boundary check: exactly MAX chars is untouched, MAX+1 is clipped.
        let at_limit = "x".repeat(HTTP_ERROR_BODY_MAX_CHARS);
        let clipped = clip_error_body(&at_limit);
        assert_eq!(
            clipped, at_limit,
            "a body exactly at the limit must be returned verbatim"
        );
        assert!(
            !clipped.contains('\u{2026}'),
            "an at-limit body must not gain an ellipsis: {clipped}"
        );

        let over_limit = "x".repeat(HTTP_ERROR_BODY_MAX_CHARS + 1);
        let clipped = clip_error_body(&over_limit);
        assert_eq!(
            clipped.chars().count(),
            HTTP_ERROR_BODY_MAX_CHARS + 1,
            "an over-limit body keeps MAX chars plus the ellipsis"
        );
        assert!(
            clipped.ends_with('\u{2026}'),
            "an over-limit body must end with an ellipsis: {clipped}"
        );
    }
}