Skip to main content

haystack_client/transport/
ws.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::time::Duration;
4
5use dashmap::DashMap;
6use futures_util::{SinkExt, StreamExt};
7use tokio::sync::{Mutex, oneshot};
8use tokio_tungstenite::{connect_async, tungstenite};
9
10use crate::error::ClientError;
11use crate::transport::Transport;
12use haystack_core::codecs::codec_for;
13use haystack_core::data::HGrid;
14use haystack_core::kinds::Kind;
15
16type WsStream =
17    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
18
19/// Default timeout for a single WS request-response round-trip.
20const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
21
22/// Maximum number of concurrent in-flight requests.
23const MAX_PENDING_REQUESTS: usize = 1024;
24
25/// WebSocket transport for communicating with a Haystack server.
26///
27/// Uses a JSON envelope with Zinc-encoded grid bodies:
28/// - Request:  `{"id": "<counter>", "op": "<op_name>", "body": "<zinc_grid>"}`
29/// - Response: `{"id": "<counter>", "body": "<zinc_grid>"}`
30///
31/// Supports concurrent in-flight requests by matching response IDs to pending
32/// oneshot channels via a background reader task.
33pub struct WsTransport {
34    writer: Mutex<futures_util::stream::SplitSink<WsStream, tungstenite::Message>>,
35    pending: Arc<DashMap<u64, oneshot::Sender<Result<HGrid, ClientError>>>>,
36    next_id: AtomicU64,
37    /// Per-request timeout duration.
38    request_timeout: Duration,
39    /// Handle to the background reader task (kept alive for the transport's lifetime).
40    _reader_handle: tokio::task::JoinHandle<()>,
41    /// Cancellation token for graceful shutdown of the reader task.
42    shutdown: tokio_util::sync::CancellationToken,
43}
44
45impl WsTransport {
46    /// Connect to a Haystack server via WebSocket.
47    ///
48    /// `url` should be a `ws://` or `wss://` URL to the server's WebSocket endpoint.
49    /// `auth_token` is the bearer token obtained from SCRAM authentication.
50    pub async fn connect(url: &str, auth_token: &str) -> Result<Self, ClientError> {
51        let request = tungstenite::http::Request::builder()
52            .uri(url)
53            .header("Authorization", format!("BEARER authToken={}", auth_token))
54            .header(
55                "Sec-WebSocket-Key",
56                tungstenite::handshake::client::generate_key(),
57            )
58            .header("Sec-WebSocket-Version", "13")
59            .header("Connection", "Upgrade")
60            .header("Upgrade", "websocket")
61            .header("Host", extract_host(url).unwrap_or_default())
62            .body(())
63            .map_err(|e| ClientError::Transport(e.to_string()))?;
64
65        let (ws_stream, _response) =
66            tokio::time::timeout(Duration::from_secs(15), connect_async(request))
67                .await
68                .map_err(|_| ClientError::Transport("WebSocket connect timed out".to_string()))?
69                .map_err(|e| ClientError::Transport(format!("WebSocket connect failed: {}", e)))?;
70
71        let (writer, reader) = ws_stream.split();
72        let pending: Arc<DashMap<u64, oneshot::Sender<Result<HGrid, ClientError>>>> =
73            Arc::new(DashMap::new());
74
75        let shutdown = tokio_util::sync::CancellationToken::new();
76        let reader_handle = spawn_reader_task(reader, Arc::clone(&pending), shutdown.child_token());
77
78        Ok(Self {
79            writer: Mutex::new(writer),
80            pending,
81            next_id: AtomicU64::new(1),
82            request_timeout: DEFAULT_REQUEST_TIMEOUT,
83            _reader_handle: reader_handle,
84            shutdown,
85        })
86    }
87
88    /// Connect with a custom request timeout.
89    pub async fn connect_with_timeout(
90        url: &str,
91        auth_token: &str,
92        timeout: Duration,
93    ) -> Result<Self, ClientError> {
94        let mut transport = Self::connect(url, auth_token).await?;
95        transport.request_timeout = timeout;
96        Ok(transport)
97    }
98}
99
100/// Spawn a background task that reads WS messages and dispatches responses
101/// to the appropriate pending oneshot channel by matching the response `id`.
102fn spawn_reader_task(
103    mut reader: futures_util::stream::SplitStream<WsStream>,
104    pending: Arc<DashMap<u64, oneshot::Sender<Result<HGrid, ClientError>>>>,
105    shutdown: tokio_util::sync::CancellationToken,
106) -> tokio::task::JoinHandle<()> {
107    tokio::spawn(async move {
108        let codec = codec_for("text/zinc");
109
110        loop {
111            tokio::select! {
112                _ = shutdown.cancelled() => {
113                    drain_pending(&pending, ClientError::ConnectionClosed);
114                    break;
115                }
116                msg = reader.next() => {
117                    let Some(msg) = msg else { break };
118                    match msg {
119                        Ok(tungstenite::Message::Text(text)) => {
120                            handle_text_message(&text, codec, &pending);
121                        }
122                        Ok(tungstenite::Message::Binary(data)) => {
123                            // Compressed message: deflate-compressed JSON envelope.
124                            if let Ok(decompressed) = decompress_deflate(&data) {
125                                handle_text_message(&decompressed, codec, &pending);
126                            }
127                        }
128                        Ok(tungstenite::Message::Close(_)) => {
129                            drain_pending(&pending, ClientError::ConnectionClosed);
130                            break;
131                        }
132                        Err(e) => {
133                            drain_pending(&pending, ClientError::Transport(e.to_string()));
134                            break;
135                        }
136                        _ => continue, // ping/pong handled by tungstenite
137                    }
138                }
139            }
140        }
141    })
142}
143
144/// Process a text (or decompressed) JSON envelope and dispatch to the pending channel.
145fn handle_text_message(
146    text: &str,
147    codec: Option<&'static dyn haystack_core::codecs::Codec>,
148    pending: &DashMap<u64, oneshot::Sender<Result<HGrid, ClientError>>>,
149) {
150    let resp: serde_json::Value = match serde_json::from_str(text) {
151        Ok(v) => v,
152        Err(_) => return,
153    };
154
155    let resp_id: u64 = match resp.get("id").and_then(|v| {
156        v.as_str()
157            .and_then(|s| s.parse().ok())
158            .or_else(|| v.as_u64())
159    }) {
160        Some(id) => id,
161        None => return,
162    };
163
164    let result = match (codec, resp.get("body").and_then(|v| v.as_str())) {
165        (Some(c), Some(body)) => match c.decode_grid(body) {
166            Ok(grid) => {
167                if grid.is_err() {
168                    let dis = grid
169                        .meta
170                        .get("dis")
171                        .and_then(|k| {
172                            if let Kind::Str(s) = k {
173                                Some(s.as_str())
174                            } else {
175                                None
176                            }
177                        })
178                        .unwrap_or("unknown server error");
179                    Err(ClientError::ServerError(dis.to_string()))
180                } else {
181                    Ok(grid)
182                }
183            }
184            Err(e) => Err(ClientError::Codec(e.to_string())),
185        },
186        _ => Err(ClientError::Codec(
187            "response missing 'body' field".to_string(),
188        )),
189    };
190
191    if let Some((_, sender)) = pending.remove(&resp_id) {
192        let _ = sender.send(result);
193    }
194}
195
196/// Notify all pending requests with the given error and clear the map.
197fn drain_pending(
198    pending: &DashMap<u64, oneshot::Sender<Result<HGrid, ClientError>>>,
199    error: ClientError,
200) {
201    let keys: Vec<u64> = pending.iter().map(|r| *r.key()).collect();
202    for key in keys {
203        if let Some((_, sender)) = pending.remove(&key) {
204            let _ = sender.send(Err(ClientError::Transport(error.to_string())));
205        }
206    }
207}
208
209/// Compress data with deflate (flate2).
210fn compress_deflate(data: &[u8]) -> Vec<u8> {
211    use flate2::Compression;
212    use flate2::write::DeflateEncoder;
213    use std::io::Write;
214
215    let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
216    let _ = encoder.write_all(data);
217    encoder.finish().unwrap_or_else(|_| data.to_vec())
218}
219
220/// Maximum decompressed payload size (10 MB) to prevent zip bomb attacks.
221const MAX_DECOMPRESSED_SIZE: u64 = 10 * 1024 * 1024;
222
223/// Decompress deflate-compressed data.
224fn decompress_deflate(data: &[u8]) -> Result<String, std::io::Error> {
225    use flate2::read::DeflateDecoder;
226    use std::io::Read;
227
228    let decoder = DeflateDecoder::new(data);
229    let mut limited = decoder.take(MAX_DECOMPRESSED_SIZE);
230    let mut output = String::new();
231    limited.read_to_string(&mut output)?;
232    Ok(output)
233}
234
235/// Minimum payload size (bytes) to consider compressing with deflate.
236const COMPRESSION_THRESHOLD: usize = 512;
237
238/// Extract the host (with optional port) from a URL string.
239fn extract_host(url: &str) -> Option<String> {
240    let parsed = url::Url::parse(url).ok()?;
241    let host = parsed.host_str()?.to_string();
242    match parsed.port() {
243        Some(port) => Some(format!("{}:{}", host, port)),
244        None => Some(host),
245    }
246}
247
248impl Transport for WsTransport {
249    async fn call(&self, op: &str, req: &HGrid) -> Result<HGrid, ClientError> {
250        // Bounded pending map check.
251        if self.pending.len() >= MAX_PENDING_REQUESTS {
252            return Err(ClientError::TooManyRequests);
253        }
254
255        let codec = codec_for("text/zinc")
256            .ok_or_else(|| ClientError::Codec("zinc codec not available".to_string()))?;
257
258        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
259
260        let body = codec
261            .encode_grid(req)
262            .map_err(|e| ClientError::Codec(e.to_string()))?;
263
264        let envelope = serde_json::json!({
265            "id": id.to_string(),
266            "op": op,
267            "body": body,
268        });
269
270        let msg_text =
271            serde_json::to_string(&envelope).map_err(|e| ClientError::Codec(e.to_string()))?;
272
273        // Compress large payloads and send as binary frame.
274        let ws_msg = if msg_text.len() >= COMPRESSION_THRESHOLD {
275            let compressed = compress_deflate(msg_text.as_bytes());
276            if compressed.len() < msg_text.len() {
277                tungstenite::Message::Binary(compressed.into())
278            } else {
279                tungstenite::Message::Text(msg_text.into())
280            }
281        } else {
282            tungstenite::Message::Text(msg_text.into())
283        };
284
285        // Register a oneshot channel for this request.
286        let (tx, rx) = oneshot::channel();
287        self.pending.insert(id, tx);
288
289        // Send the request.
290        {
291            let mut writer = self.writer.lock().await;
292            if let Err(e) = writer.send(ws_msg).await {
293                self.pending.remove(&id);
294                return Err(ClientError::Transport(e.to_string()));
295            }
296        }
297
298        // Await the response with a timeout.
299        let timeout = self.request_timeout;
300        match tokio::time::timeout(timeout, rx).await {
301            Ok(Ok(result)) => result,
302            Ok(Err(_)) => Err(ClientError::Transport(
303                "response channel closed unexpectedly".to_string(),
304            )),
305            Err(_) => {
306                self.pending.remove(&id);
307                Err(ClientError::Timeout(timeout))
308            }
309        }
310    }
311
312    async fn close(&self) -> Result<(), ClientError> {
313        self.shutdown.cancel();
314        let mut writer = self.writer.lock().await;
315        writer
316            .send(tungstenite::Message::Close(None))
317            .await
318            .map_err(|e| ClientError::Transport(e.to_string()))?;
319        Ok(())
320    }
321}
322
323impl Drop for WsTransport {
324    fn drop(&mut self) {
325        self.shutdown.cancel();
326    }
327}
328
329// ---------------------------------------------------------------------------
330// Reconnecting transport wrapper
331// ---------------------------------------------------------------------------
332
333/// Initial backoff delay before the first reconnection attempt.
334const INITIAL_BACKOFF: Duration = Duration::from_millis(250);
335/// Maximum backoff delay between reconnection attempts.
336const MAX_BACKOFF: Duration = Duration::from_secs(30);
337/// Maximum number of consecutive reconnection attempts before giving up.
338const MAX_RECONNECT_ATTEMPTS: u32 = 10;
339
340/// A WebSocket transport that automatically reconnects on connection loss.
341///
342/// Uses exponential backoff with jitter between reconnection attempts.
343/// Requests that arrive during reconnection are queued and retried once the
344/// connection is re-established.
345pub struct ReconnectingWsTransport {
346    url: String,
347    auth_token: zeroize::Zeroizing<String>,
348    request_timeout: Duration,
349    inner: Mutex<Option<Arc<WsTransport>>>,
350}
351
352impl ReconnectingWsTransport {
353    /// Create a new reconnecting transport.  An initial connection is
354    /// established immediately; use [`Self::connect`] for the async builder.
355    pub async fn connect(url: &str, auth_token: &str) -> Result<Self, ClientError> {
356        let transport = WsTransport::connect(url, auth_token).await?;
357        Ok(Self {
358            url: url.to_string(),
359            auth_token: zeroize::Zeroizing::new(auth_token.to_string()),
360            request_timeout: DEFAULT_REQUEST_TIMEOUT,
361            inner: Mutex::new(Some(Arc::new(transport))),
362        })
363    }
364
365    /// Create a new reconnecting transport with a custom request timeout.
366    pub async fn connect_with_timeout(
367        url: &str,
368        auth_token: &str,
369        timeout: Duration,
370    ) -> Result<Self, ClientError> {
371        let transport = WsTransport::connect_with_timeout(url, auth_token, timeout).await?;
372        Ok(Self {
373            url: url.to_string(),
374            auth_token: zeroize::Zeroizing::new(auth_token.to_string()),
375            request_timeout: timeout,
376            inner: Mutex::new(Some(Arc::new(transport))),
377        })
378    }
379
380    /// Try to reconnect using exponential backoff with jitter.
381    /// Returns `Ok(())` when a new connection is established, or `Err` after
382    /// exhausting all attempts.
383    async fn reconnect(&self) -> Result<(), ClientError> {
384        use rand::RngExt;
385
386        let mut backoff = INITIAL_BACKOFF;
387
388        for attempt in 1..=MAX_RECONNECT_ATTEMPTS {
389            // Add random jitter: ±25% of current backoff.
390            let jitter_range = backoff.as_millis() as u64 / 4;
391            let jitter = if jitter_range > 0 {
392                let offset = rand::rng().random_range(0..jitter_range * 2);
393                Duration::from_millis(offset)
394            } else {
395                Duration::ZERO
396            };
397            let delay = backoff
398                .saturating_add(jitter)
399                .saturating_sub(Duration::from_millis(jitter_range));
400            tokio::time::sleep(delay).await;
401
402            match WsTransport::connect_with_timeout(
403                &self.url,
404                &self.auth_token,
405                self.request_timeout,
406            )
407            .await
408            {
409                Ok(transport) => {
410                    *self.inner.lock().await = Some(Arc::new(transport));
411                    return Ok(());
412                }
413                Err(_) if attempt < MAX_RECONNECT_ATTEMPTS => {
414                    backoff = (backoff * 2).min(MAX_BACKOFF);
415                    continue;
416                }
417                Err(e) => {
418                    return Err(ClientError::Transport(format!(
419                        "reconnection failed after {MAX_RECONNECT_ATTEMPTS} attempts: {e}"
420                    )));
421                }
422            }
423        }
424
425        Err(ClientError::Transport(
426            "reconnection failed: max attempts exhausted".to_string(),
427        ))
428    }
429}
430
431impl Transport for ReconnectingWsTransport {
432    async fn call(&self, op: &str, req: &HGrid) -> Result<HGrid, ClientError> {
433        // Fast path: clone the Arc out of the lock, then drop the lock before calling.
434        let transport = {
435            let guard = self.inner.lock().await;
436            guard.as_ref().cloned()
437        };
438        if let Some(transport) = transport {
439            match transport.call(op, req).await {
440                Ok(grid) => return Ok(grid),
441                Err(ClientError::Timeout(d)) => return Err(ClientError::Timeout(d)),
442                Err(ClientError::ServerError(e)) => return Err(ClientError::ServerError(e)),
443                Err(ClientError::TooManyRequests) => {
444                    return Err(ClientError::TooManyRequests);
445                }
446                Err(_) => {
447                    // Connection-level error; fall through to reconnect.
448                }
449            }
450        }
451
452        // Drop current transport and reconnect.
453        *self.inner.lock().await = None;
454        self.reconnect().await?;
455
456        // Retry the request on the new connection.
457        let transport = {
458            let guard = self.inner.lock().await;
459            guard.as_ref().cloned()
460        };
461        match transport {
462            Some(transport) => transport.call(op, req).await,
463            None => Err(ClientError::ConnectionClosed),
464        }
465    }
466
467    async fn close(&self) -> Result<(), ClientError> {
468        let transport = self.inner.lock().await.take();
469        if let Some(transport) = transport {
470            transport.close().await
471        } else {
472            Ok(())
473        }
474    }
475}