1use std::convert::TryFrom;
14use std::time::Duration;
15
16use std::sync::Arc;
17
18use futures_util::stream::{SplitSink, SplitStream};
19use futures_util::{SinkExt, StreamExt};
20use tokio::net::TcpStream;
21use tokio::sync::Mutex;
22use tokio_tungstenite::tungstenite::client::IntoClientRequest;
23use tokio_tungstenite::tungstenite::http::HeaderValue;
24use tokio_tungstenite::tungstenite::http::StatusCode;
25use tokio_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFrame};
26use tokio_tungstenite::tungstenite::{Error as TError, Message};
27use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
28use url::Url;
29
30use crate::ws::types::{WorkerInbound, WorkerOutbound};
31
/// WebSocket subprotocol name sent in the `Sec-WebSocket-Protocol` header
/// when opening a connection.
pub const SUBPROTOCOL: &str = "studio-worker-v1";
/// Path prefix that `build_connect_url` appends to the base URL's path
/// unless the (slash-trimmed) path already ends with it.
const API_PREFIX: &str = "/graphics/api";

/// Result alias used throughout this module; the error is always
/// [`WsClientError`].
pub type WsResult<T> = Result<T, WsClientError>;
40
/// Errors surfaced by the worker WebSocket client.
#[derive(Debug, thiserror::Error)]
pub enum WsClientError {
    /// Authentication was rejected: either the upgrade request got an HTTP
    /// 401 response, or the server closed the socket with close code 4001.
    #[error("auth failed: {reason}")]
    AuthFailed { reason: String },

    /// The connection was closed (a close frame other than 4001, or the
    /// underlying library reported the connection as closed).
    #[error("connection closed by server")]
    ConnectionClosed,

    /// Any other failure while building the request or talking to the
    /// socket (invalid URL, unsupported scheme, bad header value, or a
    /// stringified tungstenite error).
    #[error("ws transport error: {0}")]
    Transport(String),

    /// The payload did not match what this client expects: JSON
    /// (de)serialization failed or a binary frame was received.
    #[error("protocol error: {0}")]
    Protocol(String),
}
62
63impl From<TError> for WsClientError {
64 fn from(value: TError) -> Self {
65 match value {
66 TError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => {
67 WsClientError::AuthFailed {
68 reason: "401 on websocket upgrade".to_string(),
69 }
70 }
71 TError::ConnectionClosed | TError::AlreadyClosed => WsClientError::ConnectionClosed,
72 other => WsClientError::Transport(other.to_string()),
73 }
74 }
75}
76
77fn build_connect_url(base_url: &str, worker_id: &str) -> WsResult<Url> {
79 let mut url = Url::parse(base_url)
80 .map_err(|e| WsClientError::Transport(format!("invalid base url: {e}")))?;
81 let new_scheme = match url.scheme() {
82 "http" => Some("ws"),
83 "https" => Some("wss"),
84 "ws" | "wss" => None, other => {
86 return Err(WsClientError::Transport(format!(
87 "unsupported scheme: {other}"
88 )))
89 }
90 };
91 if let Some(scheme) = new_scheme {
92 url.set_scheme(scheme)
93 .map_err(|_| WsClientError::Transport("set_scheme failed".to_string()))?;
94 }
95 let trimmed_path = url.path().trim_end_matches('/');
96 let prefixed = if trimmed_path.ends_with(API_PREFIX) {
100 trimmed_path.to_string()
101 } else {
102 format!("{trimmed_path}{API_PREFIX}")
103 };
104 let new_path = format!("{prefixed}/workers/{worker_id}/connect");
105 url.set_path(&new_path);
106 Ok(url)
107}
108
109pub async fn connect(base_url: &str, worker_id: &str, auth_token: &str) -> WsResult<WsClient> {
112 let url = build_connect_url(base_url, worker_id)?;
113 let mut request = url
114 .as_str()
115 .into_client_request()
116 .map_err(WsClientError::from)?;
117 let headers = request.headers_mut();
118 headers.insert(
119 "Authorization",
120 HeaderValue::try_from(format!("Bearer {auth_token}"))
121 .map_err(|e| WsClientError::Transport(format!("invalid auth header: {e}")))?,
122 );
123 headers.insert(
124 "Sec-WebSocket-Protocol",
125 HeaderValue::from_static(SUBPROTOCOL),
126 );
127
128 let (stream, _response) = tokio_tungstenite::connect_async(request).await?;
129 let (sink, source) = stream.split();
130 Ok(WsClient {
131 sink,
132 source,
133 closed: false,
134 })
135}
136
/// Write half of the split WebSocket stream; accepts [`Message`] values.
type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of the split WebSocket stream.
type WsSource = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
139
/// A connected worker WebSocket that owns both halves of the stream.
///
/// Created by [`connect`]; can be used directly via `send` / `recv` /
/// `close`, or broken into independent halves with `split`.
// NOTE(review): a manual `Debug` impl for `WsClient` exists further down in
// this file, so this `allow` looks redundant — confirm before removing.
#[allow(missing_debug_implementations)]
pub struct WsClient {
    /// Outgoing half of the socket.
    sink: WsSink,
    /// Incoming half of the socket.
    source: WsSource,
    /// Set once a close frame was received, the stream ended, or `close`
    /// was called; `recv` returns `Ok(None)` from then on.
    closed: bool,
}
148
149impl WsClient {
150 pub fn split(self) -> (WsSender, WsReceiver) {
155 let sink = Arc::new(Mutex::new(self.sink));
156 (
157 WsSender { sink },
158 WsReceiver {
159 source: self.source,
160 closed: false,
161 },
162 )
163 }
164}
165
/// Sending half produced by `WsClient::split`.
///
/// Cloning is cheap: all clones share the same sink behind an
/// `Arc<Mutex<_>>`, and the async mutex is held for the duration of each
/// send.
#[derive(Clone)]
// No `Debug` impl is provided for this type, so the lint is silenced.
#[allow(missing_debug_implementations)]
pub struct WsSender {
    /// Shared, lock-protected outgoing half of the socket.
    sink: Arc<Mutex<WsSink>>,
}
174
175impl WsSender {
176 pub async fn send(&self, frame: &WorkerInbound) -> WsResult<()> {
177 let text =
178 serde_json::to_string(frame).map_err(|e| WsClientError::Protocol(e.to_string()))?;
179 let mut guard = self.sink.lock().await;
180 guard
181 .send(Message::Text(text.into()))
182 .await
183 .map_err(WsClientError::from)
184 }
185
186 pub async fn close(&self, code: u16, reason: &str) -> WsResult<()> {
187 let frame = CloseFrame {
188 code: CloseCode::from(code),
189 reason: reason.to_owned().into(),
190 };
191 let mut guard = self.sink.lock().await;
192 let _ = tokio::time::timeout(
193 Duration::from_secs(5),
194 guard.send(Message::Close(Some(frame))),
195 )
196 .await;
197 Ok(())
198 }
199}
200
/// Receiving half produced by `WsClient::split`.
// No `Debug` impl is provided for this type, so the lint is silenced.
#[allow(missing_debug_implementations)]
pub struct WsReceiver {
    /// Incoming half of the socket.
    source: WsSource,
    /// Set once a close frame was received or the stream ended; `recv`
    /// returns `Ok(None)` from then on.
    closed: bool,
}
207
208impl WsReceiver {
209 pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
213 if self.closed {
214 return Ok(None);
215 }
216 while let Some(item) = self.source.next().await {
217 match item {
218 Ok(Message::Text(text)) => {
219 let frame: WorkerOutbound = serde_json::from_str(&text)
220 .map_err(|e| WsClientError::Protocol(e.to_string()))?;
221 return Ok(Some(frame));
222 }
223 Ok(Message::Binary(_)) => {
224 return Err(WsClientError::Protocol(
225 "unexpected binary frame".to_string(),
226 ));
227 }
228 Ok(Message::Close(frame)) => {
229 self.closed = true;
230 return Err(close_frame_to_error(frame));
231 }
232 Ok(_) => continue,
233 Err(e) => return Err(WsClientError::from(e)),
234 }
235 }
236 self.closed = true;
237 Ok(None)
238 }
239}
240
241impl std::fmt::Debug for WsClient {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 f.debug_struct("WsClient")
244 .field("closed", &self.closed)
245 .finish()
246 }
247}
248
249impl WsClient {
250 pub async fn send(&mut self, frame: &WorkerInbound) -> WsResult<()> {
252 let text =
253 serde_json::to_string(frame).map_err(|e| WsClientError::Protocol(e.to_string()))?;
254 self.sink
255 .send(Message::Text(text.into()))
256 .await
257 .map_err(WsClientError::from)
258 }
259
260 pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
265 if self.closed {
266 return Ok(None);
267 }
268 while let Some(item) = self.source.next().await {
269 match item {
270 Ok(Message::Text(text)) => {
271 let frame: WorkerOutbound = serde_json::from_str(&text)
272 .map_err(|e| WsClientError::Protocol(e.to_string()))?;
273 return Ok(Some(frame));
274 }
275 Ok(Message::Binary(_)) => {
276 return Err(WsClientError::Protocol(
277 "unexpected binary frame".to_string(),
278 ));
279 }
280 Ok(Message::Close(frame)) => {
281 self.closed = true;
282 return Err(close_frame_to_error(frame));
283 }
284 Ok(_) => continue, Err(e) => return Err(WsClientError::from(e)),
286 }
287 }
288 self.closed = true;
289 Ok(None)
290 }
291
292 pub async fn close(&mut self, code: u16, reason: &str) -> WsResult<()> {
294 if self.closed {
295 return Ok(());
296 }
297 self.closed = true;
298 let frame = CloseFrame {
299 code: CloseCode::from(code),
300 reason: reason.to_owned().into(),
301 };
302 let _ = tokio::time::timeout(
304 Duration::from_secs(5),
305 self.sink.send(Message::Close(Some(frame))),
306 )
307 .await;
308 Ok(())
309 }
310}
311
312fn close_frame_to_error(frame: Option<CloseFrame>) -> WsClientError {
313 if let Some(frame) = frame {
314 let code: u16 = frame.code.into();
315 if code == 4001 {
316 return WsClientError::AuthFailed {
317 reason: format!("server closed 4001: {}", frame.reason),
318 };
319 }
320 }
321 WsClientError::ConnectionClosed
322}
323
// Unit tests for URL construction and error mapping; none of them open a
// network connection.
#[cfg(test)]
mod tests {
    use super::*;

    // An `http` base URL is rewritten to `ws` and gets the worker path.
    #[test]
    fn build_connect_url_http_to_ws() {
        let url = build_connect_url("http://api.example/graphics/api", "w-1").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert!(url.path().ends_with("/workers/w-1/connect"));
    }

    // An `https` base URL is rewritten to `wss`; the trailing slash is
    // trimmed and the existing prefix is not duplicated.
    #[test]
    fn build_connect_url_https_to_wss() {
        let url = build_connect_url("https://api.example/graphics/api/", "w-2").unwrap();
        assert_eq!(url.scheme(), "wss");
        assert_eq!(url.path(), "/graphics/api/workers/w-2/connect");
    }

    // A base URL without the API prefix gets it inserted.
    #[test]
    fn build_connect_url_appends_graphics_api_prefix_when_missing() {
        let url = build_connect_url("http://localhost:9790", "w-3").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert_eq!(url.path(), "/graphics/api/workers/w-3/connect");
    }

    // A URL that already uses `ws` keeps its scheme.
    #[test]
    fn build_connect_url_preserves_existing_ws_scheme() {
        let url = build_connect_url("ws://localhost:9790/x", "w").unwrap();
        assert_eq!(url.scheme(), "ws");
    }

    // Schemes other than http/https/ws/wss are rejected as transport errors.
    #[test]
    fn build_connect_url_rejects_unknown_scheme() {
        let err = build_connect_url("ftp://nope/", "w").unwrap_err();
        assert!(matches!(err, WsClientError::Transport(_)));
    }

    // Unparseable input is rejected as a transport error.
    #[test]
    fn build_connect_url_rejects_invalid_url() {
        let err = build_connect_url("not a url", "w").unwrap_err();
        assert!(matches!(err, WsClientError::Transport(_)));
    }

    // Close code 4001 is reported as an authentication failure.
    #[test]
    fn close_frame_4001_maps_to_auth_failed() {
        let frame = CloseFrame {
            code: CloseCode::Library(4001),
            reason: "bad token".into(),
        };
        let err = close_frame_to_error(Some(frame));
        assert!(matches!(err, WsClientError::AuthFailed { .. }));
    }

    // Any other close code is reported as a plain connection close.
    #[test]
    fn close_frame_other_codes_map_to_connection_closed() {
        let frame = CloseFrame {
            code: CloseCode::Normal,
            reason: "bye".into(),
        };
        let err = close_frame_to_error(Some(frame));
        assert!(matches!(err, WsClientError::ConnectionClosed));
    }

    // A close message without a frame is reported as a plain connection close.
    #[test]
    fn close_frame_missing_maps_to_connection_closed() {
        let err = close_frame_to_error(None);
        assert!(matches!(err, WsClientError::ConnectionClosed));
    }

    // `TError::AlreadyClosed` converts to `ConnectionClosed` via `From`.
    #[test]
    fn transport_error_round_trips_through_from_impl() {
        let inner = TError::AlreadyClosed;
        let mapped: WsClientError = inner.into();
        assert!(matches!(mapped, WsClientError::ConnectionClosed));
    }
}