titan_rust_client/
connection.rs

1//! WebSocket connection management with auto-reconnect and stream resumption.
2
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use titan_api_codec::codec::ws::v1::ClientCodec;
10use titan_api_codec::codec::Codec;
11use titan_api_types::ws::v1::{
12    ClientRequest, RequestData, ResponseError, ResponseSuccess, ServerMessage, StreamData,
13    SwapQuoteRequest,
14};
15use tokio::net::TcpStream;
16use tokio::sync::{mpsc, oneshot, RwLock};
17use tokio_tungstenite::tungstenite::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
19
20use crate::config::TitanConfig;
21use crate::error::TitanClientError;
22use crate::state::ConnectionState;
23
/// Concrete WebSocket stream type used throughout this module: a tungstenite
/// WebSocket over a TCP socket that may or may not be wrapped in TLS.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Outcome delivered to the caller of a request: the server's success payload
/// or its error payload.
type ResponseResult = Result<ResponseSuccess, ResponseError>;
/// Shared table of in-flight requests, keyed by request ID. Each entry holds the
/// one-shot sender on which the waiting caller receives its [`ResponseResult`].
type PendingRequestsMap = Arc<RwLock<HashMap<u32, oneshot::Sender<ResponseResult>>>>;
27
/// Initial backoff delay in milliseconds.
///
/// `calculate_backoff` uses this as the delay for the first reconnect attempt and
/// doubles it for every further attempt, capped at the caller-supplied maximum.
const INITIAL_BACKOFF_MS: u64 = 100;
30
/// Information needed to resume a stream after reconnection.
///
/// One of these is stored per registered stream. After a reconnect the stored
/// `request` is re-sent as a new swap-quote stream request, and data for the
/// stream is delivered through `sender`.
#[derive(Clone)]
pub struct ResumableStream {
    /// The original request used to create the stream; replayed verbatim on resume.
    pub request: SwapQuoteRequest,
    /// Channel to send stream data to (the consumer holds the receiving half).
    pub sender: mpsc::Sender<StreamData>,
}
39
/// Shared table of registered streams, keyed by the stream ID under which
/// incoming `StreamData` / `StreamEnd` messages are looked up.
type ResumableStreamsMap = Arc<RwLock<HashMap<u32, ResumableStream>>>;
41
/// Internal message for sending requests through the connection.
///
/// Built by `Connection::send_request` and handed to the background connection
/// task over an mpsc channel.
pub struct PendingRequest {
    /// The fully-formed request (ID already assigned) to encode and send.
    pub request: ClientRequest,
    /// Where the background task delivers the server's response (or a local
    /// send/encode failure) for this request.
    pub response_tx: oneshot::Sender<ResponseResult>,
}
47
/// Manages a WebSocket connection to the Titan API with auto-reconnect.
///
/// The socket itself is owned by a background task spawned in
/// [`Connection::connect`]; this handle talks to that task through `sender` and
/// shares the stream table and state channel with it.
pub struct Connection {
    // Stored but not read by any method visible in this file.
    #[allow(dead_code)]
    config: TitanConfig,
    /// Source of request IDs for `send_request` (starts at 1, incremented atomically).
    request_id: AtomicU32,
    /// Queue of outgoing requests consumed by the background connection task.
    sender: mpsc::Sender<PendingRequest>,
    /// Publishes connection state changes; `state()` and `state_receiver()` read from it.
    state_tx: tokio::sync::watch::Sender<ConnectionState>,
    // Same map the background task uses for in-flight requests; this handle only
    // keeps a reference and does not read it in the visible code.
    #[allow(dead_code)]
    pending_requests: PendingRequestsMap,
    /// Streams to route data to and to re-create after a reconnect.
    resumable_streams: ResumableStreamsMap,
}
59
60impl Connection {
61    /// Create a new connection with the given config.
62    ///
63    /// Connects eagerly and auto-reconnects on disconnection.
64    #[tracing::instrument(skip_all)]
65    pub async fn connect(config: TitanConfig) -> Result<Self, TitanClientError> {
66        let (state_tx, _state_rx) = tokio::sync::watch::channel(ConnectionState::Disconnected {
67            reason: "Connecting...".to_string(),
68        });
69
70        let pending_requests: PendingRequestsMap = Arc::new(RwLock::new(HashMap::new()));
71        let resumable_streams: ResumableStreamsMap = Arc::new(RwLock::new(HashMap::new()));
72
73        // Connect to WebSocket
74        let ws_stream = Self::establish_connection(&config).await?;
75
76        // Create channel for sending requests
77        let (sender, receiver) = mpsc::channel::<PendingRequest>(32);
78
79        // Spawn background task with reconnection support
80        let pending_clone = pending_requests.clone();
81        let streams_clone = resumable_streams.clone();
82        let state_tx_clone = state_tx.clone();
83        let config_clone = config.clone();
84
85        tokio::spawn(Self::run_connection_loop_with_reconnect(
86            ws_stream,
87            receiver,
88            pending_clone,
89            streams_clone,
90            state_tx_clone,
91            config_clone,
92        ));
93
94        state_tx.send_replace(ConnectionState::Connected);
95
96        Ok(Self {
97            config,
98            request_id: AtomicU32::new(1),
99            sender,
100            state_tx,
101            pending_requests,
102            resumable_streams,
103        })
104    }
105
106    /// Establish WebSocket connection with authentication.
107    async fn establish_connection(config: &TitanConfig) -> Result<WsStream, TitanClientError> {
108        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
109        use tokio_tungstenite::Connector;
110
111        // Build URL with auth token as query param
112        // Only add trailing slash if URL has no path (e.g., ws://host:port -> ws://host:port/)
113        let url = if config.url.contains("/ws") || config.url.ends_with('/') {
114            // URL already has a path, just append query param
115            format!("{}?auth={}", config.url, config.token)
116        } else {
117            // URL has no path, add trailing slash first
118            format!("{}/?auth={}", config.url, config.token)
119        };
120
121        let mut request = url.into_client_request().map_err(|e| {
122            TitanClientError::Unexpected(anyhow::anyhow!("Failed to build request: {}", e))
123        })?;
124
125        // Add the Titan subprotocol header
126        request.headers_mut().insert(
127            "Sec-WebSocket-Protocol",
128            titan_api_types::ws::v1::WEBSOCKET_SUBPROTO_BASE
129                .parse()
130                .unwrap(),
131        );
132
133        let (ws_stream, _response) = if config.danger_accept_invalid_certs {
134            let tls_config = crate::tls::build_dangerous_tls_config();
135            let connector = Connector::Rustls(Arc::new(tls_config));
136            tokio_tungstenite::connect_async_tls_with_config(request, None, false, Some(connector))
137                .await
138                .map_err(TitanClientError::WebSocket)?
139        } else {
140            tokio_tungstenite::connect_async(request)
141                .await
142                .map_err(TitanClientError::WebSocket)?
143        };
144
145        Ok(ws_stream)
146    }
147
148    /// Connection loop with automatic reconnection and stream resumption.
149    async fn run_connection_loop_with_reconnect(
150        initial_ws_stream: WsStream,
151        mut request_rx: mpsc::Receiver<PendingRequest>,
152        pending_requests: PendingRequestsMap,
153        resumable_streams: ResumableStreamsMap,
154        state_tx: tokio::sync::watch::Sender<ConnectionState>,
155        config: TitanConfig,
156    ) {
157        let mut ws_stream = initial_ws_stream;
158        let mut reconnect_attempt: u32 = 0;
159        let mut request_id_counter: u32 = 1;
160
161        loop {
162            // Run the connection loop until disconnection
163            let disconnect_reason = Self::run_single_connection(
164                &mut ws_stream,
165                &mut request_rx,
166                &pending_requests,
167                &resumable_streams,
168                &state_tx,
169                &mut request_id_counter,
170            )
171            .await;
172
173            // Check if request channel is closed (client dropped)
174            if request_rx.is_closed() {
175                tracing::info!("Request channel closed, shutting down connection");
176                break;
177            }
178
179            // Start reconnection attempts
180            reconnect_attempt += 1;
181
182            // Check max attempts
183            if let Some(max) = config.max_reconnect_attempts {
184                if reconnect_attempt > max {
185                    tracing::error!("Max reconnect attempts ({}) reached, giving up", max);
186                    let _ = state_tx.send(ConnectionState::Disconnected {
187                        reason: format!(
188                            "Max reconnect attempts reached. Last error: {}",
189                            disconnect_reason
190                        ),
191                    });
192                    break;
193                }
194            }
195
196            // Calculate backoff delay with exponential increase
197            let backoff_ms = calculate_backoff(reconnect_attempt, config.max_reconnect_delay_ms);
198
199            tracing::info!(
200                attempt = reconnect_attempt,
201                backoff_ms,
202                "Reconnecting after disconnection: {}",
203                disconnect_reason
204            );
205
206            let _ = state_tx.send(ConnectionState::Reconnecting {
207                attempt: reconnect_attempt,
208            });
209
210            // Wait before reconnecting
211            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
212
213            // Attempt to reconnect
214            match Self::establish_connection(&config).await {
215                Ok(new_stream) => {
216                    ws_stream = new_stream;
217                    reconnect_attempt = 0;
218                    let _ = state_tx.send(ConnectionState::Connected);
219                    tracing::info!("Reconnected successfully");
220
221                    // Resume streams after reconnection
222                    Self::resume_streams(
223                        &mut ws_stream,
224                        &resumable_streams,
225                        &mut request_id_counter,
226                    )
227                    .await;
228                }
229                Err(e) => {
230                    tracing::warn!("Reconnection failed: {}", e);
231                    continue;
232                }
233            }
234        }
235
236        // Final cleanup
237        Self::cleanup_pending_requests(&pending_requests).await;
238    }
239
240    /// Resume all active streams after reconnection.
241    async fn resume_streams(
242        ws_stream: &mut WsStream,
243        resumable_streams: &ResumableStreamsMap,
244        request_id_counter: &mut u32,
245    ) {
246        let streams_to_resume: Vec<(u32, ResumableStream)> = {
247            let streams = resumable_streams.read().await;
248            streams.iter().map(|(k, v)| (*k, v.clone())).collect()
249        };
250
251        if streams_to_resume.is_empty() {
252            return;
253        }
254
255        tracing::info!(
256            "Resuming {} streams after reconnection",
257            streams_to_resume.len()
258        );
259
260        let codec = ClientCodec::Uncompressed;
261        let mut encoder = codec.encoder();
262        let mut decoder = codec.decoder();
263
264        for (old_stream_id, resumable) in streams_to_resume {
265            let request_id = *request_id_counter;
266            *request_id_counter += 1;
267
268            let request = ClientRequest {
269                id: request_id,
270                data: RequestData::NewSwapQuoteStream(resumable.request.clone()),
271            };
272
273            // Encode and send the request
274            let encoded = match encoder.encode_mut(&request) {
275                Ok(data) => data.to_vec(),
276                Err(e) => {
277                    tracing::error!("Failed to encode stream resume request: {}", e);
278                    continue;
279                }
280            };
281
282            if let Err(e) = ws_stream.send(Message::Binary(encoded.into())).await {
283                tracing::error!("Failed to send stream resume request: {}", e);
284                continue;
285            }
286
287            // Wait for response to get new stream ID
288            match ws_stream.next().await {
289                Some(Ok(Message::Binary(data))) => {
290                    match decoder.decode_mut(data) {
291                        Ok(ServerMessage::Response(response)) => {
292                            if let Some(stream_info) = response.stream {
293                                let new_stream_id = stream_info.id;
294
295                                // Update the stream mapping
296                                let mut streams = resumable_streams.write().await;
297                                if let Some(stream) = streams.remove(&old_stream_id) {
298                                    streams.insert(new_stream_id, stream);
299                                    tracing::info!(
300                                        old_id = old_stream_id,
301                                        new_id = new_stream_id,
302                                        "Stream resumed with new ID"
303                                    );
304                                }
305                            }
306                        }
307                        Ok(ServerMessage::Error(error)) => {
308                            tracing::error!(
309                                "Failed to resume stream {}: {}",
310                                old_stream_id,
311                                error.message
312                            );
313                            // Remove the failed stream
314                            let mut streams = resumable_streams.write().await;
315                            streams.remove(&old_stream_id);
316                        }
317                        Ok(_) => {
318                            tracing::warn!("Unexpected response type during stream resumption");
319                        }
320                        Err(e) => {
321                            tracing::error!("Failed to decode stream resume response: {}", e);
322                        }
323                    }
324                }
325                Some(Ok(_)) => {
326                    tracing::warn!("Unexpected message type during stream resumption");
327                }
328                Some(Err(e)) => {
329                    tracing::error!("WebSocket error during stream resumption: {}", e);
330                    break;
331                }
332                None => {
333                    tracing::error!("Connection closed during stream resumption");
334                    break;
335                }
336            }
337        }
338    }
339
340    /// Run a single connection until disconnection.
341    async fn run_single_connection(
342        ws_stream: &mut WsStream,
343        request_rx: &mut mpsc::Receiver<PendingRequest>,
344        pending_requests: &PendingRequestsMap,
345        resumable_streams: &ResumableStreamsMap,
346        state_tx: &tokio::sync::watch::Sender<ConnectionState>,
347        request_id_counter: &mut u32,
348    ) -> String {
349        let codec = ClientCodec::Uncompressed;
350        let mut encoder = codec.encoder();
351        let mut decoder = codec.decoder();
352
353        let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
354
355        loop {
356            tokio::select! {
357                Some(pending_req) = request_rx.recv() => {
358                    let request_id = pending_req.request.id;
359                    *request_id_counter = request_id.max(*request_id_counter) + 1;
360
361                    {
362                        let mut pending_map = pending_requests.write().await;
363                        pending_map.insert(request_id, pending_req.response_tx);
364                    }
365
366                    match encoder.encode_mut(&pending_req.request) {
367                        Ok(data) => {
368                            if let Err(e) = ws_sink.send(Message::Binary(data.to_vec().into())).await {
369                                tracing::error!("Failed to send WebSocket message: {}", e);
370                                let mut pending_map = pending_requests.write().await;
371                                if let Some(tx) = pending_map.remove(&request_id) {
372                                    let _ = tx.send(Err(ResponseError {
373                                        request_id,
374                                        code: 0,
375                                        message: format!("Send failed: {}", e),
376                                    }));
377                                }
378                            }
379                        }
380                        Err(e) => {
381                            tracing::error!("Failed to encode request: {}", e);
382                            let mut pending_map = pending_requests.write().await;
383                            if let Some(tx) = pending_map.remove(&request_id) {
384                                let _ = tx.send(Err(ResponseError {
385                                    request_id,
386                                    code: 0,
387                                    message: format!("Encode failed: {}", e),
388                                }));
389                            }
390                        }
391                    }
392                }
393
394                Some(msg_result) = ws_stream_rx.next() => {
395                    match msg_result {
396                        Ok(Message::Binary(data)) => {
397                            match decoder.decode_mut(data) {
398                                Ok(server_msg) => {
399                                    Self::handle_server_message(
400                                        server_msg,
401                                        pending_requests,
402                                        resumable_streams,
403                                    ).await;
404                                }
405                                Err(e) => {
406                                    tracing::error!("Failed to decode server message: {}", e);
407                                }
408                            }
409                        }
410                        Ok(Message::Close(frame)) => {
411                            let reason = frame
412                                .map(|f| f.reason.to_string())
413                                .unwrap_or_else(|| "Server closed connection".to_string());
414                            tracing::warn!("WebSocket closed: {}", reason);
415                            let _ = state_tx.send(ConnectionState::Disconnected {
416                                reason: reason.clone(),
417                            });
418                            return reason;
419                        }
420                        Ok(Message::Ping(data)) => {
421                            let _ = ws_sink.send(Message::Pong(data)).await;
422                        }
423                        Ok(_) => {}
424                        Err(e) => {
425                            let reason = format!("WebSocket error: {}", e);
426                            let error_str = e.to_string();
427                            if error_str.contains("Connection reset without closing handshake") {
428                                tracing::info!("{}", reason);
429                            } else {
430                                tracing::error!("{}", reason);
431                            }
432                            let _ = state_tx.send(ConnectionState::Disconnected {
433                                reason: reason.clone(),
434                            });
435                            return reason;
436                        }
437                    }
438                }
439
440                else => {
441                    return "Channel closed".to_string();
442                }
443            }
444        }
445    }
446
447    /// Handle a message received from the server.
448    async fn handle_server_message(
449        msg: ServerMessage,
450        pending_requests: &PendingRequestsMap,
451        resumable_streams: &ResumableStreamsMap,
452    ) {
453        match msg {
454            ServerMessage::Response(response) => {
455                let mut pending = pending_requests.write().await;
456                if let Some(tx) = pending.remove(&response.request_id) {
457                    let _ = tx.send(Ok(response));
458                }
459            }
460            ServerMessage::Error(error) => {
461                let mut pending = pending_requests.write().await;
462                if let Some(tx) = pending.remove(&error.request_id) {
463                    let _ = tx.send(Err(error));
464                }
465            }
466            ServerMessage::StreamData(data) => {
467                let streams = resumable_streams.read().await;
468                if let Some(stream) = streams.get(&data.id) {
469                    let _ = stream.sender.send(data).await;
470                }
471            }
472            ServerMessage::StreamEnd(end) => {
473                let mut streams = resumable_streams.write().await;
474                streams.remove(&end.id);
475            }
476            ServerMessage::Other(_) => {
477                tracing::warn!("Received unknown server message type");
478            }
479        }
480    }
481
482    /// Cleanup pending requests on final shutdown.
483    async fn cleanup_pending_requests(pending_requests: &PendingRequestsMap) {
484        let mut pending_map = pending_requests.write().await;
485        for (request_id, tx) in pending_map.drain() {
486            let _ = tx.send(Err(ResponseError {
487                request_id,
488                code: 0,
489                message: "Connection closed".to_string(),
490            }));
491        }
492    }
493
494    /// Send a request and wait for response.
495    #[tracing::instrument(skip_all)]
496    pub async fn send_request(
497        &self,
498        data: RequestData,
499    ) -> Result<ResponseSuccess, TitanClientError> {
500        let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
501        let request = ClientRequest {
502            id: request_id,
503            data,
504        };
505
506        let (response_tx, response_rx) = oneshot::channel();
507
508        self.sender
509            .send(PendingRequest {
510                request,
511                response_tx,
512            })
513            .await
514            .map_err(|_| TitanClientError::Unexpected(anyhow::anyhow!("Connection closed")))?;
515
516        let response = response_rx.await.map_err(|_| {
517            TitanClientError::Unexpected(anyhow::anyhow!("Response channel closed"))
518        })?;
519
520        response.map_err(|e| TitanClientError::ServerError {
521            code: e.code,
522            message: e.message,
523        })
524    }
525
526    /// Register a resumable stream.
527    pub async fn register_stream(
528        &self,
529        stream_id: u32,
530        request: SwapQuoteRequest,
531        sender: mpsc::Sender<StreamData>,
532    ) {
533        let mut streams = self.resumable_streams.write().await;
534        streams.insert(stream_id, ResumableStream { request, sender });
535    }
536
537    /// Unregister a stream.
538    pub async fn unregister_stream(&self, stream_id: u32) {
539        let mut streams = self.resumable_streams.write().await;
540        streams.remove(&stream_id);
541    }
542
543    /// Get a receiver for connection state changes.
544    pub fn state_receiver(&self) -> tokio::sync::watch::Receiver<ConnectionState> {
545        self.state_tx.subscribe()
546    }
547
548    /// Get the current connection state.
549    pub fn state(&self) -> ConnectionState {
550        self.state_tx.borrow().clone()
551    }
552
553    /// Get all active stream IDs.
554    pub async fn active_stream_ids(&self) -> Vec<u32> {
555        let streams = self.resumable_streams.read().await;
556        streams.keys().copied().collect()
557    }
558
559    /// Stop all active streams gracefully.
560    ///
561    /// Sends StopStream for each active stream and clears the stream map.
562    #[tracing::instrument(skip_all)]
563    pub async fn stop_all_streams(&self) {
564        use titan_api_types::ws::v1::StopStreamRequest;
565
566        let stream_ids = self.active_stream_ids().await;
567
568        if stream_ids.is_empty() {
569            return;
570        }
571
572        tracing::info!("Stopping {} active streams", stream_ids.len());
573
574        for stream_id in stream_ids {
575            // Send stop request (fire and forget)
576            let _ = self
577                .send_request(RequestData::StopStream(StopStreamRequest { id: stream_id }))
578                .await;
579        }
580
581        // Clear all streams
582        let mut streams = self.resumable_streams.write().await;
583        streams.clear();
584    }
585
586    /// Graceful shutdown: stop all streams and signal connection loop to exit.
587    #[tracing::instrument(skip_all)]
588    pub async fn shutdown(&self) {
589        // Stop all streams first
590        self.stop_all_streams().await;
591
592        // Update state
593        let _ = self.state_tx.send(ConnectionState::Disconnected {
594            reason: "Client shutdown".to_string(),
595        });
596
597        // The connection loop will exit when it detects the sender is closed
598        // (which happens when Connection is dropped)
599    }
600}
601
602/// Calculate exponential backoff.
603fn calculate_backoff(attempt: u32, max_delay_ms: u64) -> u64 {
604    let base_delay = INITIAL_BACKOFF_MS * 2u64.saturating_pow(attempt.saturating_sub(1));
605    base_delay.min(max_delay_ms)
606}