haystack_client/transport/
ws.rs1use std::sync::Arc;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::time::Duration;
4
5use dashmap::DashMap;
6use futures_util::{SinkExt, StreamExt};
7use tokio::sync::{Mutex, oneshot};
8use tokio_tungstenite::{connect_async, tungstenite};
9
10use crate::error::ClientError;
11use crate::transport::Transport;
12use haystack_core::codecs::codec_for;
13use haystack_core::data::HGrid;
14use haystack_core::kinds::Kind;
15
16type WsStream =
17 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
18
/// Response wait time used when the caller does not supply its own timeout.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Maximum number of requests awaiting a response on one connection; further
/// calls are rejected with `ClientError::TooManyRequests`.
const MAX_PENDING_REQUESTS: usize = 1 << 10;
24
25pub struct WsTransport {
34 writer: Mutex<futures_util::stream::SplitSink<WsStream, tungstenite::Message>>,
35 pending: Arc<DashMap<u64, oneshot::Sender<Result<HGrid, ClientError>>>>,
36 next_id: AtomicU64,
37 request_timeout: Duration,
39 _reader_handle: tokio::task::JoinHandle<()>,
41 shutdown: tokio_util::sync::CancellationToken,
43}
44
45impl WsTransport {
46 pub async fn connect(url: &str, auth_token: &str) -> Result<Self, ClientError> {
51 let request = tungstenite::http::Request::builder()
52 .uri(url)
53 .header("Authorization", format!("BEARER authToken={}", auth_token))
54 .header(
55 "Sec-WebSocket-Key",
56 tungstenite::handshake::client::generate_key(),
57 )
58 .header("Sec-WebSocket-Version", "13")
59 .header("Connection", "Upgrade")
60 .header("Upgrade", "websocket")
61 .header("Host", extract_host(url).unwrap_or_default())
62 .body(())
63 .map_err(|e| ClientError::Transport(e.to_string()))?;
64
65 let (ws_stream, _response) =
66 tokio::time::timeout(Duration::from_secs(15), connect_async(request))
67 .await
68 .map_err(|_| ClientError::Transport("WebSocket connect timed out".to_string()))?
69 .map_err(|e| ClientError::Transport(format!("WebSocket connect failed: {}", e)))?;
70
71 let (writer, reader) = ws_stream.split();
72 let pending: Arc<DashMap<u64, oneshot::Sender<Result<HGrid, ClientError>>>> =
73 Arc::new(DashMap::new());
74
75 let shutdown = tokio_util::sync::CancellationToken::new();
76 let reader_handle = spawn_reader_task(reader, Arc::clone(&pending), shutdown.child_token());
77
78 Ok(Self {
79 writer: Mutex::new(writer),
80 pending,
81 next_id: AtomicU64::new(1),
82 request_timeout: DEFAULT_REQUEST_TIMEOUT,
83 _reader_handle: reader_handle,
84 shutdown,
85 })
86 }
87
88 pub async fn connect_with_timeout(
90 url: &str,
91 auth_token: &str,
92 timeout: Duration,
93 ) -> Result<Self, ClientError> {
94 let mut transport = Self::connect(url, auth_token).await?;
95 transport.request_timeout = timeout;
96 Ok(transport)
97 }
98}
99
100fn spawn_reader_task(
103 mut reader: futures_util::stream::SplitStream<WsStream>,
104 pending: Arc<DashMap<u64, oneshot::Sender<Result<HGrid, ClientError>>>>,
105 shutdown: tokio_util::sync::CancellationToken,
106) -> tokio::task::JoinHandle<()> {
107 tokio::spawn(async move {
108 let codec = codec_for("text/zinc");
109
110 loop {
111 tokio::select! {
112 _ = shutdown.cancelled() => {
113 drain_pending(&pending, ClientError::ConnectionClosed);
114 break;
115 }
116 msg = reader.next() => {
117 let Some(msg) = msg else { break };
118 match msg {
119 Ok(tungstenite::Message::Text(text)) => {
120 handle_text_message(&text, codec, &pending);
121 }
122 Ok(tungstenite::Message::Binary(data)) => {
123 if let Ok(decompressed) = decompress_deflate(&data) {
125 handle_text_message(&decompressed, codec, &pending);
126 }
127 }
128 Ok(tungstenite::Message::Close(_)) => {
129 drain_pending(&pending, ClientError::ConnectionClosed);
130 break;
131 }
132 Err(e) => {
133 drain_pending(&pending, ClientError::Transport(e.to_string()));
134 break;
135 }
136 _ => continue, }
138 }
139 }
140 }
141 })
142}
143
144fn handle_text_message(
146 text: &str,
147 codec: Option<&'static dyn haystack_core::codecs::Codec>,
148 pending: &DashMap<u64, oneshot::Sender<Result<HGrid, ClientError>>>,
149) {
150 let resp: serde_json::Value = match serde_json::from_str(text) {
151 Ok(v) => v,
152 Err(_) => return,
153 };
154
155 let resp_id: u64 = match resp.get("id").and_then(|v| {
156 v.as_str()
157 .and_then(|s| s.parse().ok())
158 .or_else(|| v.as_u64())
159 }) {
160 Some(id) => id,
161 None => return,
162 };
163
164 let result = match (codec, resp.get("body").and_then(|v| v.as_str())) {
165 (Some(c), Some(body)) => match c.decode_grid(body) {
166 Ok(grid) => {
167 if grid.is_err() {
168 let dis = grid
169 .meta
170 .get("dis")
171 .and_then(|k| {
172 if let Kind::Str(s) = k {
173 Some(s.as_str())
174 } else {
175 None
176 }
177 })
178 .unwrap_or("unknown server error");
179 Err(ClientError::ServerError(dis.to_string()))
180 } else {
181 Ok(grid)
182 }
183 }
184 Err(e) => Err(ClientError::Codec(e.to_string())),
185 },
186 _ => Err(ClientError::Codec(
187 "response missing 'body' field".to_string(),
188 )),
189 };
190
191 if let Some((_, sender)) = pending.remove(&resp_id) {
192 let _ = sender.send(result);
193 }
194}
195
196fn drain_pending(
198 pending: &DashMap<u64, oneshot::Sender<Result<HGrid, ClientError>>>,
199 error: ClientError,
200) {
201 let keys: Vec<u64> = pending.iter().map(|r| *r.key()).collect();
202 for key in keys {
203 if let Some((_, sender)) = pending.remove(&key) {
204 let _ = sender.send(Err(ClientError::Transport(error.to_string())));
205 }
206 }
207}
208
209fn compress_deflate(data: &[u8]) -> Vec<u8> {
211 use flate2::Compression;
212 use flate2::write::DeflateEncoder;
213 use std::io::Write;
214
215 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
216 let _ = encoder.write_all(data);
217 encoder.finish().unwrap_or_else(|_| data.to_vec())
218}
219
220const MAX_DECOMPRESSED_SIZE: u64 = 10 * 1024 * 1024;
222
223fn decompress_deflate(data: &[u8]) -> Result<String, std::io::Error> {
225 use flate2::read::DeflateDecoder;
226 use std::io::Read;
227
228 let decoder = DeflateDecoder::new(data);
229 let mut limited = decoder.take(MAX_DECOMPRESSED_SIZE);
230 let mut output = String::new();
231 limited.read_to_string(&mut output)?;
232 Ok(output)
233}
234
/// Serialized envelopes at least this many bytes long are deflate-compressed
/// (and sent as binary frames) whenever compression makes them smaller.
const COMPRESSION_THRESHOLD: usize = 1 << 9;
237
238fn extract_host(url: &str) -> Option<String> {
240 let parsed = url::Url::parse(url).ok()?;
241 let host = parsed.host_str()?.to_string();
242 match parsed.port() {
243 Some(port) => Some(format!("{}:{}", host, port)),
244 None => Some(host),
245 }
246}
247
248impl Transport for WsTransport {
249 async fn call(&self, op: &str, req: &HGrid) -> Result<HGrid, ClientError> {
250 if self.pending.len() >= MAX_PENDING_REQUESTS {
252 return Err(ClientError::TooManyRequests);
253 }
254
255 let codec = codec_for("text/zinc")
256 .ok_or_else(|| ClientError::Codec("zinc codec not available".to_string()))?;
257
258 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
259
260 let body = codec
261 .encode_grid(req)
262 .map_err(|e| ClientError::Codec(e.to_string()))?;
263
264 let envelope = serde_json::json!({
265 "id": id.to_string(),
266 "op": op,
267 "body": body,
268 });
269
270 let msg_text =
271 serde_json::to_string(&envelope).map_err(|e| ClientError::Codec(e.to_string()))?;
272
273 let ws_msg = if msg_text.len() >= COMPRESSION_THRESHOLD {
275 let compressed = compress_deflate(msg_text.as_bytes());
276 if compressed.len() < msg_text.len() {
277 tungstenite::Message::Binary(compressed.into())
278 } else {
279 tungstenite::Message::Text(msg_text.into())
280 }
281 } else {
282 tungstenite::Message::Text(msg_text.into())
283 };
284
285 let (tx, rx) = oneshot::channel();
287 self.pending.insert(id, tx);
288
289 {
291 let mut writer = self.writer.lock().await;
292 if let Err(e) = writer.send(ws_msg).await {
293 self.pending.remove(&id);
294 return Err(ClientError::Transport(e.to_string()));
295 }
296 }
297
298 let timeout = self.request_timeout;
300 match tokio::time::timeout(timeout, rx).await {
301 Ok(Ok(result)) => result,
302 Ok(Err(_)) => Err(ClientError::Transport(
303 "response channel closed unexpectedly".to_string(),
304 )),
305 Err(_) => {
306 self.pending.remove(&id);
307 Err(ClientError::Timeout(timeout))
308 }
309 }
310 }
311
312 async fn close(&self) -> Result<(), ClientError> {
313 self.shutdown.cancel();
314 let mut writer = self.writer.lock().await;
315 writer
316 .send(tungstenite::Message::Close(None))
317 .await
318 .map_err(|e| ClientError::Transport(e.to_string()))?;
319 Ok(())
320 }
321}
322
323impl Drop for WsTransport {
324 fn drop(&mut self) {
325 self.shutdown.cancel();
326 }
327}
328
/// Base delay before the first reconnect attempt; doubled after each failure.
const INITIAL_BACKOFF: Duration = Duration::from_micros(250_000);
/// Cap on the base reconnect delay (jitter is applied on top of the capped value).
const MAX_BACKOFF: Duration = Duration::from_millis(30_000);
/// Connection attempts made by a single reconnect before it gives up.
const MAX_RECONNECT_ATTEMPTS: u32 = 10;
339
340pub struct ReconnectingWsTransport {
346 url: String,
347 auth_token: zeroize::Zeroizing<String>,
348 request_timeout: Duration,
349 inner: Mutex<Option<Arc<WsTransport>>>,
350}
351
352impl ReconnectingWsTransport {
353 pub async fn connect(url: &str, auth_token: &str) -> Result<Self, ClientError> {
356 let transport = WsTransport::connect(url, auth_token).await?;
357 Ok(Self {
358 url: url.to_string(),
359 auth_token: zeroize::Zeroizing::new(auth_token.to_string()),
360 request_timeout: DEFAULT_REQUEST_TIMEOUT,
361 inner: Mutex::new(Some(Arc::new(transport))),
362 })
363 }
364
365 pub async fn connect_with_timeout(
367 url: &str,
368 auth_token: &str,
369 timeout: Duration,
370 ) -> Result<Self, ClientError> {
371 let transport = WsTransport::connect_with_timeout(url, auth_token, timeout).await?;
372 Ok(Self {
373 url: url.to_string(),
374 auth_token: zeroize::Zeroizing::new(auth_token.to_string()),
375 request_timeout: timeout,
376 inner: Mutex::new(Some(Arc::new(transport))),
377 })
378 }
379
380 async fn reconnect(&self) -> Result<(), ClientError> {
384 use rand::RngExt;
385
386 let mut backoff = INITIAL_BACKOFF;
387
388 for attempt in 1..=MAX_RECONNECT_ATTEMPTS {
389 let jitter_range = backoff.as_millis() as u64 / 4;
391 let jitter = if jitter_range > 0 {
392 let offset = rand::rng().random_range(0..jitter_range * 2);
393 Duration::from_millis(offset)
394 } else {
395 Duration::ZERO
396 };
397 let delay = backoff
398 .saturating_add(jitter)
399 .saturating_sub(Duration::from_millis(jitter_range));
400 tokio::time::sleep(delay).await;
401
402 match WsTransport::connect_with_timeout(
403 &self.url,
404 &self.auth_token,
405 self.request_timeout,
406 )
407 .await
408 {
409 Ok(transport) => {
410 *self.inner.lock().await = Some(Arc::new(transport));
411 return Ok(());
412 }
413 Err(_) if attempt < MAX_RECONNECT_ATTEMPTS => {
414 backoff = (backoff * 2).min(MAX_BACKOFF);
415 continue;
416 }
417 Err(e) => {
418 return Err(ClientError::Transport(format!(
419 "reconnection failed after {MAX_RECONNECT_ATTEMPTS} attempts: {e}"
420 )));
421 }
422 }
423 }
424
425 Err(ClientError::Transport(
426 "reconnection failed: max attempts exhausted".to_string(),
427 ))
428 }
429}
430
431impl Transport for ReconnectingWsTransport {
432 async fn call(&self, op: &str, req: &HGrid) -> Result<HGrid, ClientError> {
433 let transport = {
435 let guard = self.inner.lock().await;
436 guard.as_ref().cloned()
437 };
438 if let Some(transport) = transport {
439 match transport.call(op, req).await {
440 Ok(grid) => return Ok(grid),
441 Err(ClientError::Timeout(d)) => return Err(ClientError::Timeout(d)),
442 Err(ClientError::ServerError(e)) => return Err(ClientError::ServerError(e)),
443 Err(ClientError::TooManyRequests) => {
444 return Err(ClientError::TooManyRequests);
445 }
446 Err(_) => {
447 }
449 }
450 }
451
452 *self.inner.lock().await = None;
454 self.reconnect().await?;
455
456 let transport = {
458 let guard = self.inner.lock().await;
459 guard.as_ref().cloned()
460 };
461 match transport {
462 Some(transport) => transport.call(op, req).await,
463 None => Err(ClientError::ConnectionClosed),
464 }
465 }
466
467 async fn close(&self) -> Result<(), ClientError> {
468 let transport = self.inner.lock().await.take();
469 if let Some(transport) = transport {
470 transport.close().await
471 } else {
472 Ok(())
473 }
474 }
475}