1use std::convert::TryFrom;
21use std::time::{Duration, Instant};
22
23use std::sync::Arc;
24
25use futures_util::stream::{SplitSink, SplitStream};
26use futures_util::{SinkExt, StreamExt};
27use tokio::net::TcpStream;
28use tokio::sync::Mutex;
29use tokio_tungstenite::tungstenite::client::IntoClientRequest;
30use tokio_tungstenite::tungstenite::http::HeaderValue;
31use tokio_tungstenite::tungstenite::http::StatusCode;
32use tokio_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFrame};
33use tokio_tungstenite::tungstenite::{Error as TError, Message};
34use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
35use tracing::{debug, warn};
36use url::Url;
37
38use crate::ws::types::{WorkerInbound, WorkerOutbound};
39
/// Websocket subprotocol offered in the `Sec-WebSocket-Protocol` request header.
pub const SUBPROTOCOL: &str = "studio-worker-v1";

/// `tracing` target used by every log line emitted from this module.
const TRACE_TARGET: &str = "studio_worker::ws::client";
/// API path prefix; appended to the base URL's path unless the path already ends with it.
const API_PREFIX: &str = "/graphics/api";

/// Upper bound applied by `connect` to the whole `connect_async` call.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(15);

/// Result alias used throughout the websocket client.
pub type WsResult<T> = Result<T, WsClientError>;
59
/// Errors produced by the worker websocket client.
#[derive(Debug, thiserror::Error)]
pub enum WsClientError {
    /// The server rejected the worker's credentials: HTTP 401 on the upgrade
    /// request, or a close frame with code 4001.
    #[error("auth failed: {reason}")]
    AuthFailed { reason: String },

    /// The connection is closed: a close frame (other than 4001) or no close
    /// frame was received, or tungstenite reported `ConnectionClosed` /
    /// `AlreadyClosed`.
    #[error("connection closed by server")]
    ConnectionClosed,

    /// Setup or transport failure (invalid/unsupported URL, invalid auth header,
    /// connect timeout, or any other tungstenite error), carried as text.
    #[error("ws transport error: {0}")]
    Transport(String),

    /// A frame could not be (de)serialized, or the server sent an unexpected
    /// frame type (e.g. binary).
    #[error("protocol error: {0}")]
    Protocol(String),
}
81
82impl From<TError> for WsClientError {
83 fn from(value: TError) -> Self {
84 match value {
85 TError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => {
86 WsClientError::AuthFailed {
87 reason: "401 on websocket upgrade".to_string(),
88 }
89 }
90 TError::ConnectionClosed | TError::AlreadyClosed => WsClientError::ConnectionClosed,
91 other => WsClientError::Transport(other.to_string()),
92 }
93 }
94}
95
96fn build_connect_url(base_url: &str, worker_id: &str) -> WsResult<Url> {
98 let mut url = Url::parse(base_url)
99 .map_err(|e| WsClientError::Transport(format!("invalid base url: {e}")))?;
100 let new_scheme = match url.scheme() {
101 "http" => Some("ws"),
102 "https" => Some("wss"),
103 "ws" | "wss" => None, other => {
105 return Err(WsClientError::Transport(format!(
106 "unsupported scheme: {other}"
107 )))
108 }
109 };
110 if let Some(scheme) = new_scheme {
111 url.set_scheme(scheme)
112 .map_err(|_| WsClientError::Transport("set_scheme failed".to_string()))?;
113 }
114 let trimmed_path = url.path().trim_end_matches('/');
115 let prefixed = if trimmed_path.ends_with(API_PREFIX) {
119 trimmed_path.to_string()
120 } else {
121 format!("{trimmed_path}{API_PREFIX}")
122 };
123 let new_path = format!("{prefixed}/workers/{worker_id}/connect");
124 url.set_path(&new_path);
125 Ok(url)
126}
127
128pub async fn connect(base_url: &str, worker_id: &str, auth_token: &str) -> WsResult<WsClient> {
135 let started = Instant::now();
136 let result = connect_inner(base_url, worker_id, auth_token, CONNECT_TIMEOUT).await;
137 let elapsed_ms = started.elapsed().as_millis() as u64;
138 match &result {
139 Ok(_) => debug!(
140 target: TRACE_TARGET,
141 op = "connect",
142 worker_id,
143 elapsed_ms,
144 "websocket established"
145 ),
146 Err(e) => warn!(
147 target: TRACE_TARGET,
148 op = "connect",
149 worker_id,
150 elapsed_ms,
151 error = %e,
152 "websocket connect failed"
153 ),
154 }
155 result
156}
157
158async fn connect_inner(
159 base_url: &str,
160 worker_id: &str,
161 auth_token: &str,
162 connect_timeout: Duration,
163) -> WsResult<WsClient> {
164 let url = build_connect_url(base_url, worker_id)?;
165 debug!(
166 target: TRACE_TARGET,
167 op = "connect",
168 worker_id,
169 url = %url,
170 "opening websocket"
171 );
172 let mut request = url
173 .as_str()
174 .into_client_request()
175 .map_err(WsClientError::from)?;
176 let headers = request.headers_mut();
177 headers.insert(
178 "Authorization",
179 HeaderValue::try_from(format!("Bearer {auth_token}"))
180 .map_err(|e| WsClientError::Transport(format!("invalid auth header: {e}")))?,
181 );
182 headers.insert(
183 "Sec-WebSocket-Protocol",
184 HeaderValue::from_static(SUBPROTOCOL),
185 );
186
187 let (stream, _response) = match tokio::time::timeout(
188 connect_timeout,
189 tokio_tungstenite::connect_async(request),
190 )
191 .await
192 {
193 Ok(result) => result?,
194 Err(_elapsed) => {
195 return Err(WsClientError::Transport(format!(
196 "connect timed out after {connect_timeout:?}"
197 )))
198 }
199 };
200 let (sink, source) = stream.split();
201 Ok(WsClient {
202 sink,
203 source,
204 closed: false,
205 })
206}
207
/// Write half of the split websocket stream.
type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of the split websocket stream.
type WsSource = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

/// A connected worker websocket owning both halves of the stream.
///
/// Use it directly for send/recv from a single owner, or call `split` to get
/// a cloneable sender and an owned receiver.
// NOTE(review): `WsClient` has a hand-written `Debug` impl in this module, so
// this `allow` appears redundant — confirm and remove.
#[allow(missing_debug_implementations)]
pub struct WsClient {
    sink: WsSink,
    source: WsSource,
    /// Set when `close` is called or `recv` observes the end of the
    /// connection; `recv` and `close` short-circuit once it is true.
    closed: bool,
}
219
220impl WsClient {
221 pub fn split(self) -> (WsSender, WsReceiver) {
226 let sink = Arc::new(Mutex::new(self.sink));
227 (
228 WsSender { sink },
229 WsReceiver {
230 source: self.source,
231 closed: false,
232 },
233 )
234 }
235}
236
/// Cloneable sending half produced by `WsClient::split`.
///
/// Every clone shares one sink behind a `tokio::sync::Mutex`, so concurrent
/// sends from different tasks are serialized on that lock.
#[derive(Clone)]
#[allow(missing_debug_implementations)]
pub struct WsSender {
    sink: Arc<Mutex<WsSink>>,
}
245
246impl WsSender {
247 pub async fn send(&self, frame: &WorkerInbound) -> WsResult<()> {
248 let text = serde_json::to_string(frame).map_err(|e| {
249 let err = WsClientError::Protocol(e.to_string());
250 log_send_error(frame, &err);
251 err
252 })?;
253 let mut guard = self.sink.lock().await;
254 guard.send(Message::Text(text.into())).await.map_err(|e| {
255 let err = WsClientError::from(e);
256 log_send_error(frame, &err);
257 err
258 })
259 }
260
261 pub async fn close(&self, code: u16, reason: &str) -> WsResult<()> {
262 debug!(target: TRACE_TARGET, op = "close", code, reason, "closing websocket");
263 let frame = CloseFrame {
264 code: CloseCode::from(code),
265 reason: reason.to_owned().into(),
266 };
267 let mut guard = self.sink.lock().await;
268 if tokio::time::timeout(
269 Duration::from_secs(5),
270 guard.send(Message::Close(Some(frame))),
271 )
272 .await
273 .is_err()
274 {
275 warn!(target: TRACE_TARGET, op = "close", code, "timed out sending close frame");
276 }
277 Ok(())
278 }
279}
280
/// Owned receiving half produced by `WsClient::split`.
#[allow(missing_debug_implementations)]
pub struct WsReceiver {
    source: WsSource,
    /// Once true, `recv` returns `Ok(None)` without polling `source`.
    closed: bool,
}
287
288impl WsReceiver {
289 pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
293 if self.closed {
294 return Ok(None);
295 }
296 while let Some(item) = self.source.next().await {
297 match classify_incoming(item) {
298 RecvStep::Yield(frame) => return Ok(Some(frame)),
299 RecvStep::Skip => continue,
300 RecvStep::Fail(e) => return Err(e),
301 RecvStep::Closed(e) => {
302 self.closed = true;
303 return Err(e);
304 }
305 }
306 }
307 self.closed = true;
308 debug!(target: TRACE_TARGET, op = "recv", "stream ended (no close frame)");
309 Ok(None)
310 }
311}
312
313impl std::fmt::Debug for WsClient {
314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 f.debug_struct("WsClient")
316 .field("closed", &self.closed)
317 .finish()
318 }
319}
320
321impl WsClient {
322 pub async fn send(&mut self, frame: &WorkerInbound) -> WsResult<()> {
324 let text = serde_json::to_string(frame).map_err(|e| {
325 let err = WsClientError::Protocol(e.to_string());
326 log_send_error(frame, &err);
327 err
328 })?;
329 self.sink
330 .send(Message::Text(text.into()))
331 .await
332 .map_err(|e| {
333 let err = WsClientError::from(e);
334 log_send_error(frame, &err);
335 err
336 })
337 }
338
339 pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
344 if self.closed {
345 return Ok(None);
346 }
347 while let Some(item) = self.source.next().await {
348 match classify_incoming(item) {
349 RecvStep::Yield(frame) => return Ok(Some(frame)),
350 RecvStep::Skip => continue,
351 RecvStep::Fail(e) => return Err(e),
352 RecvStep::Closed(e) => {
353 self.closed = true;
354 return Err(e);
355 }
356 }
357 }
358 self.closed = true;
359 debug!(target: TRACE_TARGET, op = "recv", "stream ended (no close frame)");
360 Ok(None)
361 }
362
363 pub async fn close(&mut self, code: u16, reason: &str) -> WsResult<()> {
365 if self.closed {
366 return Ok(());
367 }
368 self.closed = true;
369 debug!(target: TRACE_TARGET, op = "close", code, reason, "closing websocket");
370 let frame = CloseFrame {
371 code: CloseCode::from(code),
372 reason: reason.to_owned().into(),
373 };
374 if tokio::time::timeout(
376 Duration::from_secs(5),
377 self.sink.send(Message::Close(Some(frame))),
378 )
379 .await
380 .is_err()
381 {
382 warn!(target: TRACE_TARGET, op = "close", code, "timed out sending close frame");
383 }
384 Ok(())
385 }
386}
387
/// Short name of an outbound frame's variant, used as the `frame` field in
/// send-failure logs (see `log_send_error`). Only the variant is inspected;
/// the payload is neither serialized nor logged.
fn frame_label(frame: &WorkerInbound) -> &'static str {
    // Exhaustive on purpose: adding a `WorkerInbound` variant forces a label here.
    match frame {
        WorkerInbound::Hello(_) => "hello",
        WorkerInbound::Heartbeat { .. } => "heartbeat",
        WorkerInbound::Accept { .. } => "accept",
        WorkerInbound::Reject { .. } => "reject",
        WorkerInbound::CompleteJson { .. } => "completeJson",
        WorkerInbound::Fail { .. } => "fail",
        WorkerInbound::LogBatch { .. } => "logBatch",
        WorkerInbound::ReadyForMore => "readyForMore",
    }
}
403
404fn log_send_error(frame: &WorkerInbound, err: &WsClientError) {
408 warn!(
409 target: TRACE_TARGET,
410 op = "send",
411 frame = frame_label(frame),
412 error = %err,
413 "failed to send frame"
414 );
415}
416
/// Outcome of classifying one item read from the websocket stream.
enum RecvStep {
    /// A decoded application frame to hand to the caller.
    Yield(WorkerOutbound),
    /// An item with no application payload (e.g. Ping/Pong); keep reading.
    Skip,
    /// An error returned to the caller without marking the connection closed.
    Fail(WsClientError),
    /// The connection is closed (e.g. the server sent a Close frame); the
    /// reader marks itself closed and returns the error.
    Closed(WsClientError),
}
430
431fn classify_incoming(item: Result<Message, TError>) -> RecvStep {
434 match item {
435 Ok(Message::Text(text)) => match serde_json::from_str::<WorkerOutbound>(&text) {
436 Ok(frame) => RecvStep::Yield(frame),
437 Err(e) => {
438 warn!(
439 target: TRACE_TARGET,
440 op = "recv",
441 error = %e,
442 "dropping unparseable text frame"
443 );
444 RecvStep::Fail(WsClientError::Protocol(e.to_string()))
445 }
446 },
447 Ok(Message::Binary(_)) => {
448 warn!(
449 target: TRACE_TARGET,
450 op = "recv",
451 "rejecting unexpected binary frame"
452 );
453 RecvStep::Fail(WsClientError::Protocol(
454 "unexpected binary frame".to_string(),
455 ))
456 }
457 Ok(Message::Close(frame)) => {
458 let err = close_frame_to_error(frame);
459 match &err {
460 WsClientError::AuthFailed { reason } => warn!(
461 target: TRACE_TARGET,
462 op = "recv",
463 reason = %reason,
464 "server closed connection: auth failed"
465 ),
466 _ => debug!(
467 target: TRACE_TARGET,
468 op = "recv",
469 "server closed connection"
470 ),
471 }
472 RecvStep::Closed(err)
473 }
474 Ok(_) => RecvStep::Skip,
476 Err(e) => {
477 let mapped = WsClientError::from(e);
478 match &mapped {
479 WsClientError::ConnectionClosed => debug!(
483 target: TRACE_TARGET,
484 op = "recv",
485 "connection closed by peer"
486 ),
487 other => warn!(
488 target: TRACE_TARGET,
489 op = "recv",
490 error = %other,
491 "transport error while reading frame"
492 ),
493 }
494 RecvStep::Fail(mapped)
495 }
496 }
497}
498
499fn close_frame_to_error(frame: Option<CloseFrame>) -> WsClientError {
500 if let Some(frame) = frame {
501 let code: u16 = frame.code.into();
502 if code == 4001 {
503 return WsClientError::AuthFailed {
504 reason: format!("server closed 4001: {}", frame.reason),
505 };
506 }
507 }
508 WsClientError::ConnectionClosed
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    // ---- URL construction -------------------------------------------------

    #[test]
    fn build_connect_url_http_to_ws() {
        let url = build_connect_url("http://api.example/graphics/api", "w-1").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert!(url.path().ends_with("/workers/w-1/connect"));
    }

    #[test]
    fn build_connect_url_https_to_wss() {
        // A trailing slash on the base path must not yield a double slash.
        let url = build_connect_url("https://api.example/graphics/api/", "w-2").unwrap();
        assert_eq!(url.scheme(), "wss");
        assert_eq!(url.path(), "/graphics/api/workers/w-2/connect");
    }

    #[test]
    fn build_connect_url_appends_graphics_api_prefix_when_missing() {
        let url = build_connect_url("http://localhost:9790", "w-3").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert_eq!(url.path(), "/graphics/api/workers/w-3/connect");
    }

    #[test]
    fn build_connect_url_preserves_existing_ws_scheme() {
        let url = build_connect_url("ws://localhost:9790/x", "w").unwrap();
        assert_eq!(url.scheme(), "ws");
    }

    #[test]
    fn build_connect_url_rejects_unknown_scheme() {
        let err = build_connect_url("ftp://nope/", "w").unwrap_err();
        assert!(matches!(err, WsClientError::Transport(_)));
    }

    #[test]
    fn build_connect_url_rejects_invalid_url() {
        let err = build_connect_url("not a url", "w").unwrap_err();
        assert!(matches!(err, WsClientError::Transport(_)));
    }

    // ---- Close-frame and error mapping -------------------------------------

    #[test]
    fn close_frame_4001_maps_to_auth_failed() {
        let frame = CloseFrame {
            code: CloseCode::Library(4001),
            reason: "bad token".into(),
        };
        let err = close_frame_to_error(Some(frame));
        assert!(matches!(err, WsClientError::AuthFailed { .. }));
    }

    #[test]
    fn close_frame_other_codes_map_to_connection_closed() {
        let frame = CloseFrame {
            code: CloseCode::Normal,
            reason: "bye".into(),
        };
        let err = close_frame_to_error(Some(frame));
        assert!(matches!(err, WsClientError::ConnectionClosed));
    }

    #[test]
    fn close_frame_missing_maps_to_connection_closed() {
        let err = close_frame_to_error(None);
        assert!(matches!(err, WsClientError::ConnectionClosed));
    }

    #[test]
    fn transport_error_round_trips_through_from_impl() {
        let inner = TError::AlreadyClosed;
        let mapped: WsClientError = inner.into();
        assert!(matches!(mapped, WsClientError::ConnectionClosed));
    }

    // ---- Log-capturing tests ----------------------------------------------
    // `capture` runs the closure and returns the formatted log output as a
    // string, which the tests search for level, target and structured fields.

    use crate::test_support::capture;

    #[test]
    fn classify_rejects_binary_frame_with_warn() {
        let logs = capture(|| {
            let step = classify_incoming(Ok(Message::Binary(vec![1, 2, 3].into())));
            assert!(matches!(step, RecvStep::Fail(WsClientError::Protocol(_))));
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(
            logs.contains("studio_worker::ws::client"),
            "expected target, got: {logs}"
        );
        assert!(logs.contains("op=\"recv\""), "expected op field: {logs}");
        assert!(logs.contains("binary"), "expected reason: {logs}");
    }

    #[test]
    fn classify_warns_on_unparseable_text_frame() {
        let logs = capture(|| {
            let step = classify_incoming(Ok(Message::Text("not json".into())));
            assert!(matches!(step, RecvStep::Fail(WsClientError::Protocol(_))));
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(logs.contains("op=\"recv\""), "expected op field: {logs}");
    }

    #[test]
    fn classify_warns_on_4001_close_frame() {
        let logs = capture(|| {
            let frame = CloseFrame {
                code: CloseCode::Library(4001),
                reason: "invalid auth token".into(),
            };
            let step = classify_incoming(Ok(Message::Close(Some(frame))));
            assert!(matches!(
                step,
                RecvStep::Closed(WsClientError::AuthFailed { .. })
            ));
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(logs.contains("auth failed"), "expected reason: {logs}");
    }

    #[test]
    fn classify_debug_logs_on_normal_close_frame() {
        // A normal close is routine and must not be reported at WARN.
        let logs = capture(|| {
            let frame = CloseFrame {
                code: CloseCode::Normal,
                reason: "bye".into(),
            };
            let step = classify_incoming(Ok(Message::Close(Some(frame))));
            assert!(matches!(
                step,
                RecvStep::Closed(WsClientError::ConnectionClosed)
            ));
        });
        assert!(logs.contains("DEBUG"), "expected DEBUG, got: {logs}");
        assert!(!logs.contains("WARN"), "normal close must not warn: {logs}");
        assert!(logs.contains("server closed"), "expected message: {logs}");
    }

    #[test]
    fn classify_yields_valid_frame_without_warning() {
        let logs = capture(|| {
            let json = serde_json::json!({ "type": "heartbeatAck" }).to_string();
            let step = classify_incoming(Ok(Message::Text(json.into())));
            assert!(matches!(
                step,
                RecvStep::Yield(WorkerOutbound::HeartbeatAck)
            ));
        });
        assert!(
            !logs.contains("WARN"),
            "a valid frame should not warn: {logs}"
        );
    }

    #[test]
    fn classify_skips_control_frames() {
        assert!(matches!(
            classify_incoming(Ok(Message::Ping(Vec::new().into()))),
            RecvStep::Skip
        ));
        assert!(matches!(
            classify_incoming(Ok(Message::Pong(Vec::new().into()))),
            RecvStep::Skip
        ));
    }

    // Covers every variant so a new `WorkerInbound` variant without a label
    // (or with an unexpected one) is caught here.
    #[test]
    fn frame_label_names_every_inbound_variant() {
        use crate::types::WorkerCapabilities;
        let caps = WorkerCapabilities {
            machine_name: String::new(),
            username: String::new(),
            agent_version: String::new(),
            engine: String::new(),
            vram_total_gb: 0.0,
            vram_threshold_gb: 0.0,
            auto_enabled: false,
            auto_start: false,
            supported_models: vec![],
            task_kinds: vec![],
            supported_models_per_kind: Default::default(),
        };
        assert_eq!(
            frame_label(&WorkerInbound::Hello(crate::ws::types::HelloFrame {
                auth_token: String::new(),
                capabilities: caps.clone(),
            })),
            "hello"
        );
        assert_eq!(
            frame_label(&WorkerInbound::Heartbeat {
                capabilities: caps,
                current_job_id: None,
            }),
            "heartbeat"
        );
        assert_eq!(
            frame_label(&WorkerInbound::Accept { job_id: "j".into() }),
            "accept"
        );
        assert_eq!(
            frame_label(&WorkerInbound::Reject {
                job_id: "j".into(),
                reason: "r".into(),
            }),
            "reject"
        );
        assert_eq!(
            frame_label(&WorkerInbound::CompleteJson {
                job_id: "j".into(),
                result: serde_json::Value::Null,
                prompt: None,
            }),
            "completeJson"
        );
        assert_eq!(
            frame_label(&WorkerInbound::Fail {
                job_id: "j".into(),
                error: "e".into(),
                retryable: true,
            }),
            "fail"
        );
        assert_eq!(
            frame_label(&WorkerInbound::LogBatch { entries: vec![] }),
            "logBatch"
        );
        assert_eq!(frame_label(&WorkerInbound::ReadyForMore), "readyForMore");
    }

    #[test]
    fn send_error_logs_warn_with_frame_label() {
        let logs = capture(|| {
            log_send_error(
                &WorkerInbound::Accept {
                    job_id: "j-1".into(),
                },
                &WsClientError::ConnectionClosed,
            );
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(logs.contains("op=\"send\""), "expected op field: {logs}");
        assert!(
            logs.contains("frame=\"accept\""),
            "expected frame label: {logs}"
        );
    }

    // ---- Connect behaviour --------------------------------------------------

    #[tokio::test]
    async fn connect_times_out_against_a_stalling_upgrade() {
        // The listener accepts the TCP connection but never answers the HTTP
        // upgrade, so only the connect timeout can end the attempt.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _accepted = listener.accept().await;
            tokio::time::sleep(Duration::from_secs(30)).await;
        });
        let url = format!("http://{addr}/graphics/api");
        let started = Instant::now();
        let result = connect_inner(&url, "w", "tok", Duration::from_millis(150)).await;
        assert!(
            matches!(result, Err(WsClientError::Transport(_))),
            "expected a transport timeout, got {result:?}"
        );
        assert!(
            started.elapsed() < Duration::from_secs(2),
            "connect must time out promptly, took {:?}",
            started.elapsed()
        );
    }

    #[test]
    fn connect_failure_logs_warn_breadcrumb() {
        let logs = capture(|| {
            // The closure given to `capture` is synchronous, so drive the async
            // connect on a dedicated current-thread runtime. Assumes nothing
            // listens on 127.0.0.1:1, so the connect fails fast.
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let result = rt.block_on(connect("http://127.0.0.1:1/graphics/api", "w-err", "tok"));
            assert!(result.is_err(), "connect to a dead port should fail");
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(logs.contains("op=\"connect\""), "expected op field: {logs}");
        assert!(
            logs.contains("websocket connect failed"),
            "expected message: {logs}"
        );
        assert!(
            logs.contains("worker_id=\"w-err\""),
            "expected worker_id field: {logs}"
        );
    }
}