Skip to main content

titan_rust_client/
connection.rs

//! WebSocket connection management with auto-reconnect and stream resumption.
2
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use titan_api_codec::codec::ws::v1::ClientCodec;
10use titan_api_codec::codec::Codec;
11use titan_api_types::ws::v1::{
12    ClientRequest, RequestData, ResponseError, ResponseSuccess, ServerMessage, StreamData,
13    SwapQuoteRequest,
14};
15use tokio::net::TcpStream;
16use tokio::sync::{mpsc, oneshot, RwLock};
17use tokio_tungstenite::tungstenite::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
19
20use crate::config::TitanConfig;
21use crate::error::TitanClientError;
22use crate::state::ConnectionState;
23
24type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
25type ResponseResult = Result<ResponseSuccess, ResponseError>;
26type PendingRequestsMap = Arc<RwLock<HashMap<u32, oneshot::Sender<ResponseResult>>>>;
27
/// Base reconnect delay in milliseconds: the wait before the first reconnect attempt, doubled
/// for each further attempt.
pub const INITIAL_BACKOFF_MS: u64 = 100;

/// Keepalive ping interval in milliseconds, used when the configured interval is 0.
pub const DEFAULT_PING_INTERVAL_MS: u64 = 25 * 1_000;

/// Default pong timeout in milliseconds. If no pong arrives within this window, the connection
/// is treated as dead and a reconnect is triggered.
pub const DEFAULT_PONG_TIMEOUT_MS: u64 = 10 * 1_000;
36
37/// Information needed to resume a stream after reconnection.
38#[derive(Clone)]
39pub struct ResumableStream {
40    /// The original request used to create the stream.
41    pub request: SwapQuoteRequest,
42    /// Channel to send stream data to.
43    pub sender: mpsc::Sender<StreamData>,
44}
45
46type ResumableStreamsMap = Arc<RwLock<HashMap<u32, ResumableStream>>>;
47
48/// Internal message for sending requests through the connection
49pub struct PendingRequest {
50    pub request: ClientRequest,
51    pub response_tx: oneshot::Sender<ResponseResult>,
52}
53
54/// Manages a WebSocket connection to the Titan API with auto-reconnect.
55pub struct Connection {
56    #[expect(dead_code)]
57    config: TitanConfig,
58    request_id: AtomicU32,
59    sender: mpsc::Sender<PendingRequest>,
60    state_tx: tokio::sync::watch::Sender<ConnectionState>,
61    #[expect(dead_code)]
62    pending_requests: PendingRequestsMap,
63    resumable_streams: ResumableStreamsMap,
64}
65
66impl Connection {
67    /// Create a new connection with the given config.
68    ///
69    /// Connects eagerly and auto-reconnects on disconnection.
70    #[tracing::instrument(skip_all)]
71    pub async fn connect(config: TitanConfig) -> Result<Self, TitanClientError> {
72        let (state_tx, _state_rx) = tokio::sync::watch::channel(ConnectionState::Disconnected {
73            reason: "Connecting...".to_string(),
74        });
75
76        let pending_requests: PendingRequestsMap = Arc::new(RwLock::new(HashMap::new()));
77        let resumable_streams: ResumableStreamsMap = Arc::new(RwLock::new(HashMap::new()));
78
79        // Connect to WebSocket
80        let ws_stream = Self::establish_connection(&config).await?;
81
82        // Create channel for sending requests
83        let (sender, receiver) = mpsc::channel::<PendingRequest>(32);
84
85        // Spawn background task with reconnection support
86        let pending_clone = pending_requests.clone();
87        let streams_clone = resumable_streams.clone();
88        let state_tx_clone = state_tx.clone();
89        let config_clone = config.clone();
90
91        tokio::spawn(Self::run_connection_loop_with_reconnect(
92            ws_stream,
93            receiver,
94            pending_clone,
95            streams_clone,
96            state_tx_clone,
97            config_clone,
98        ));
99
100        state_tx.send_replace(ConnectionState::Connected);
101
102        Ok(Self {
103            config,
104            request_id: AtomicU32::new(1),
105            sender,
106            state_tx,
107            pending_requests,
108            resumable_streams,
109        })
110    }
111
112    /// Establish WebSocket connection with authentication.
113    async fn establish_connection(config: &TitanConfig) -> Result<WsStream, TitanClientError> {
114        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
115        use tokio_tungstenite::Connector;
116
117        // Build URL with auth token as query param
118        // Only add trailing slash if URL has no path (e.g., ws://host:port -> ws://host:port/)
119        let url = if config.url.contains("/ws") || config.url.ends_with('/') {
120            // URL already has a path, just append query param
121            format!("{}?auth={}", config.url, config.token)
122        } else {
123            // URL has no path, add trailing slash first
124            format!("{}/?auth={}", config.url, config.token)
125        };
126
127        let mut request = url.into_client_request().map_err(|e| {
128            TitanClientError::Unexpected(anyhow::anyhow!("Failed to build request: {}", e))
129        })?;
130
131        // Add the Titan subprotocol header
132        request.headers_mut().insert(
133            "Sec-WebSocket-Protocol",
134            titan_api_types::ws::v1::WEBSOCKET_SUBPROTO_BASE
135                .parse()
136                .map_err(|e| {
137                    TitanClientError::Unexpected(anyhow::anyhow!(
138                        "Sec-WebSocket-Protocol fail: {e}"
139                    ))
140                })?,
141        );
142
143        let (ws_stream, _response) = if config.danger_accept_invalid_certs {
144            let tls_config = crate::tls::build_dangerous_tls_config();
145            let connector = Connector::Rustls(Arc::new(tls_config));
146            tokio_tungstenite::connect_async_tls_with_config(request, None, false, Some(connector))
147                .await
148                .map_err(TitanClientError::WebSocket)?
149        } else {
150            tokio_tungstenite::connect_async(request)
151                .await
152                .map_err(TitanClientError::WebSocket)?
153        };
154
155        Ok(ws_stream)
156    }
157
158    /// Connection loop with automatic reconnection and stream resumption.
159    async fn run_connection_loop_with_reconnect(
160        initial_ws_stream: WsStream,
161        mut request_rx: mpsc::Receiver<PendingRequest>,
162        pending_requests: PendingRequestsMap,
163        resumable_streams: ResumableStreamsMap,
164        state_tx: tokio::sync::watch::Sender<ConnectionState>,
165        config: TitanConfig,
166    ) {
167        let mut ws_stream = initial_ws_stream;
168        let mut reconnect_attempt: u32 = 0;
169        let mut request_id_counter: u32 = 1;
170
171        loop {
172            // Run the connection loop until disconnection
173            let disconnect_reason = Self::run_single_connection(
174                &mut ws_stream,
175                &mut request_rx,
176                &pending_requests,
177                &resumable_streams,
178                &state_tx,
179                &mut request_id_counter,
180                &config,
181            )
182            .await;
183
184            // Check if request channel is closed (client dropped)
185            if request_rx.is_closed() {
186                tracing::info!("Request channel closed, shutting down connection");
187                break;
188            }
189
190            // Start reconnection attempts
191            reconnect_attempt += 1;
192
193            // Check max attempts
194            if let Some(max) = config.max_reconnect_attempts {
195                if reconnect_attempt > max {
196                    tracing::error!("Max reconnect attempts ({}) reached, giving up", max);
197                    let _ = state_tx.send(ConnectionState::Disconnected {
198                        reason: format!(
199                            "Max reconnect attempts reached. Last error: {}",
200                            disconnect_reason
201                        ),
202                    });
203                    break;
204                }
205            }
206
207            // Calculate backoff delay with exponential increase
208            let backoff_ms = calculate_backoff(reconnect_attempt, config.max_reconnect_delay_ms);
209
210            tracing::debug!(
211                attempt = reconnect_attempt,
212                backoff_ms,
213                "Reconnecting after disconnection: {}",
214                disconnect_reason
215            );
216
217            let _ = state_tx.send(ConnectionState::Reconnecting {
218                attempt: reconnect_attempt,
219            });
220
221            // Wait before reconnecting
222            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
223
224            // Attempt to reconnect
225            match Self::establish_connection(&config).await {
226                Ok(new_stream) => {
227                    ws_stream = new_stream;
228                    reconnect_attempt = 0;
229                    let _ = state_tx.send(ConnectionState::Connected);
230                    tracing::debug!("Reconnected successfully");
231
232                    // Resume streams after reconnection
233                    Self::resume_streams(
234                        &mut ws_stream,
235                        &resumable_streams,
236                        &mut request_id_counter,
237                    )
238                    .await;
239                }
240                Err(e) => {
241                    tracing::warn!("Reconnection failed: {}", e);
242                }
243            }
244        }
245
246        // Final cleanup
247        Self::cleanup_pending_requests(&pending_requests).await;
248    }
249
250    /// Resume all active streams after reconnection.
251    async fn resume_streams(
252        ws_stream: &mut WsStream,
253        resumable_streams: &ResumableStreamsMap,
254        request_id_counter: &mut u32,
255    ) {
256        let streams_to_resume: Vec<(u32, ResumableStream)> = {
257            let streams = resumable_streams.read().await;
258            streams.iter().map(|(k, v)| (*k, v.clone())).collect()
259        };
260
261        if streams_to_resume.is_empty() {
262            return;
263        }
264
265        tracing::info!(
266            "Resuming {} streams after reconnection",
267            streams_to_resume.len()
268        );
269
270        let codec = ClientCodec::Uncompressed;
271        let mut encoder = codec.encoder();
272        let mut decoder = codec.decoder();
273
274        for (old_stream_id, resumable) in streams_to_resume {
275            let request_id = *request_id_counter;
276            *request_id_counter += 1;
277
278            let request = ClientRequest {
279                id: request_id,
280                data: RequestData::NewSwapQuoteStream(resumable.request.clone()),
281            };
282
283            // Encode and send the request
284            let encoded = match encoder.encode_mut(&request) {
285                Ok(data) => data.to_vec(),
286                Err(e) => {
287                    tracing::error!("Failed to encode stream resume request: {}", e);
288                    continue;
289                }
290            };
291
292            if let Err(e) = ws_stream.send(Message::Binary(encoded.into())).await {
293                tracing::error!("Failed to send stream resume request: {}", e);
294                continue;
295            }
296
297            // Wait for response to get new stream ID
298            match ws_stream.next().await {
299                Some(Ok(Message::Binary(data))) => {
300                    match decoder.decode_mut(data) {
301                        Ok(ServerMessage::Response(response)) => {
302                            if let Some(stream_info) = response.stream {
303                                let new_stream_id = stream_info.id;
304
305                                // Update the stream mapping
306                                let mut streams = resumable_streams.write().await;
307                                if let Some(stream) = streams.remove(&old_stream_id) {
308                                    streams.insert(new_stream_id, stream);
309                                    tracing::info!(
310                                        old_id = old_stream_id,
311                                        new_id = new_stream_id,
312                                        "Stream resumed with new ID"
313                                    );
314                                }
315                            }
316                        }
317                        Ok(ServerMessage::Error(error)) => {
318                            tracing::error!(
319                                "Failed to resume stream {}: {}",
320                                old_stream_id,
321                                error.message
322                            );
323                            // Remove the failed stream
324                            let mut streams = resumable_streams.write().await;
325                            streams.remove(&old_stream_id);
326                        }
327                        Ok(_) => {
328                            tracing::warn!("Unexpected response type during stream resumption");
329                        }
330                        Err(e) => {
331                            tracing::error!("Failed to decode stream resume response: {}", e);
332                        }
333                    }
334                }
335                Some(Ok(_)) => {
336                    tracing::warn!("Unexpected message type during stream resumption");
337                }
338                Some(Err(e)) => {
339                    tracing::error!("WebSocket error during stream resumption: {}", e);
340                    break;
341                }
342                None => {
343                    tracing::error!("Connection closed during stream resumption");
344                    break;
345                }
346            }
347        }
348    }
349
350    /// Run a single connection until disconnection.
351    async fn run_single_connection(
352        ws_stream: &mut WsStream,
353        request_rx: &mut mpsc::Receiver<PendingRequest>,
354        pending_requests: &PendingRequestsMap,
355        resumable_streams: &ResumableStreamsMap,
356        state_tx: &tokio::sync::watch::Sender<ConnectionState>,
357        request_id_counter: &mut u32,
358        config: &TitanConfig,
359    ) -> String {
360        let codec = ClientCodec::Uncompressed;
361        let mut encoder = codec.encoder();
362        let mut decoder = codec.decoder();
363
364        let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
365
366        let ping_interval_ms = if config.ping_interval_ms > 0 {
367            config.ping_interval_ms
368        } else {
369            DEFAULT_PING_INTERVAL_MS
370        };
371        let mut ping_timer = tokio::time::interval(Duration::from_millis(ping_interval_ms));
372        ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
373
374        let pong_timeout = Duration::from_millis(config.pong_timeout_ms);
375        let mut last_pong = tokio::time::Instant::now();
376
377        loop {
378            tokio::select! {
379                Some(pending_req) = request_rx.recv() => {
380                    let request_id = pending_req.request.id;
381                    *request_id_counter = request_id.max(*request_id_counter) + 1;
382
383                    {
384                        let mut pending_map = pending_requests.write().await;
385                        pending_map.insert(request_id, pending_req.response_tx);
386                    }
387
388                    match encoder.encode_mut(&pending_req.request) {
389                        Ok(data) => {
390                            if let Err(e) = ws_sink.send(Message::Binary(data.to_vec().into())).await {
391                                tracing::error!("Failed to send WebSocket message: {e}");
392                                let mut pending_map = pending_requests.write().await;
393                                if let Some(tx) = pending_map.remove(&request_id) {
394                                    let _ = tx.send(Err(ResponseError {
395                                        request_id,
396                                        code: 0,
397                                        message: format!("Send failed: {e}"),
398                                    }));
399                                }
400                            }
401                        }
402                        Err(e) => {
403                            tracing::error!("Failed to encode request: {e}");
404                            let mut pending_map = pending_requests.write().await;
405                            if let Some(tx) = pending_map.remove(&request_id) {
406                                let _ = tx.send(Err(ResponseError {
407                                    request_id,
408                                    code: 0,
409                                    message: format!("Encode failed: {e}"),
410                                }));
411                            }
412                        }
413                    }
414                }
415
416                Some(msg_result) = ws_stream_rx.next() => {
417                    match msg_result {
418                        Ok(Message::Binary(data)) => {
419                            match decoder.decode_mut(data) {
420                                Ok(server_msg) => {
421                                    Self::handle_server_message(
422                                        server_msg,
423                                        pending_requests,
424                                        resumable_streams,
425                                    ).await;
426                                }
427                                Err(e) => {
428                                    tracing::error!("Failed to decode server message: {e}");
429                                }
430                            }
431                        }
432                        Ok(Message::Close(frame)) => {
433                            let reason = frame.map_or_else(|| "Server closed connection".to_string(), |f| f.reason.to_string());
434                            tracing::warn!("WebSocket closed: {reason}");
435                            let _ = state_tx.send(ConnectionState::Disconnected {
436                                reason: reason.clone(),
437                            });
438                            return reason;
439                        }
440                        Ok(Message::Ping(data)) => {
441                            let _ = ws_sink.send(Message::Pong(data)).await;
442                        }
443                        Ok(Message::Pong(_)) => {
444                            last_pong = tokio::time::Instant::now();
445                            tracing::trace!("Received pong from server");
446                        }
447                        Ok(_) => {}
448                        Err(e) => {
449                            let reason = format!("WebSocket error: {e}");
450                            let error_str = e.to_string();
451                            if error_str.contains("Connection reset without closing handshake") {
452                                tracing::debug!("{reason}");
453                            } else {
454                                tracing::error!("{reason}");
455                            }
456                            let _ = state_tx.send(ConnectionState::Disconnected {
457                                reason: reason.clone(),
458                            });
459                            return reason;
460                        }
461                    }
462                }
463
464                _ = ping_timer.tick() => {
465                    if config.pong_timeout_ms > 0 && last_pong.elapsed() > pong_timeout {
466                        let reason = "Pong timeout".to_string();
467                        let timeout_ms = config.pong_timeout_ms;
468                        tracing::warn!("No pong received within {timeout_ms}ms, triggering reconnect");
469                        let _ = state_tx.send(ConnectionState::Disconnected {
470                            reason: reason.clone(),
471                        });
472                        return reason;
473                    }
474
475                    if let Err(e) = ws_sink.send(Message::Ping(vec![].into())).await {
476                        let reason = format!("Failed to send ping: {e}");
477                        tracing::warn!("{reason}");
478                        let _ = state_tx.send(ConnectionState::Disconnected {
479                            reason: reason.clone(),
480                        });
481                        return reason;
482                    }
483                    tracing::trace!("Sent keepalive ping");
484                }
485
486                else => {
487                    return "Channel closed".to_string();
488                }
489            }
490        }
491    }
492
493    /// Handle a message received from the server.
494    async fn handle_server_message(
495        msg: ServerMessage,
496        pending_requests: &PendingRequestsMap,
497        resumable_streams: &ResumableStreamsMap,
498    ) {
499        match msg {
500            ServerMessage::Response(response) => {
501                let mut pending = pending_requests.write().await;
502                if let Some(tx) = pending.remove(&response.request_id) {
503                    let _ = tx.send(Ok(response));
504                }
505            }
506            ServerMessage::Error(error) => {
507                let mut pending = pending_requests.write().await;
508                if let Some(tx) = pending.remove(&error.request_id) {
509                    let _ = tx.send(Err(error));
510                }
511            }
512            ServerMessage::StreamData(data) => {
513                let streams = resumable_streams.read().await;
514                if let Some(stream) = streams.get(&data.id) {
515                    let _ = stream.sender.send(data).await;
516                }
517            }
518            ServerMessage::StreamEnd(end) => {
519                let mut streams = resumable_streams.write().await;
520                streams.remove(&end.id);
521            }
522            ServerMessage::Other(_) => {
523                tracing::warn!("Received unknown server message type");
524            }
525        }
526    }
527
528    /// Cleanup pending requests on final shutdown.
529    async fn cleanup_pending_requests(pending_requests: &PendingRequestsMap) {
530        let mut pending_map = pending_requests.write().await;
531        for (request_id, tx) in pending_map.drain() {
532            let _ = tx.send(Err(ResponseError {
533                request_id,
534                code: 0,
535                message: "Connection closed".to_string(),
536            }));
537        }
538    }
539
540    /// Send a request and wait for response.
541    #[tracing::instrument(skip_all)]
542    pub async fn send_request(
543        &self,
544        data: RequestData,
545    ) -> Result<ResponseSuccess, TitanClientError> {
546        let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
547        let request = ClientRequest {
548            id: request_id,
549            data,
550        };
551
552        let (response_tx, response_rx) = oneshot::channel();
553
554        self.sender
555            .send(PendingRequest {
556                request,
557                response_tx,
558            })
559            .await
560            .map_err(|_| TitanClientError::Unexpected(anyhow::anyhow!("Connection closed")))?;
561
562        let response = response_rx.await.map_err(|_| {
563            TitanClientError::Unexpected(anyhow::anyhow!("Response channel closed"))
564        })?;
565
566        response.map_err(|e| TitanClientError::ServerError {
567            code: e.code,
568            message: e.message,
569        })
570    }
571
572    /// Register a resumable stream.
573    pub async fn register_stream(
574        &self,
575        stream_id: u32,
576        request: SwapQuoteRequest,
577        sender: mpsc::Sender<StreamData>,
578    ) {
579        let mut streams = self.resumable_streams.write().await;
580        streams.insert(stream_id, ResumableStream { request, sender });
581    }
582
583    /// Unregister a stream.
584    pub async fn unregister_stream(&self, stream_id: u32) {
585        let mut streams = self.resumable_streams.write().await;
586        streams.remove(&stream_id);
587    }
588
589    /// Get a receiver for connection state changes.
590    pub fn state_receiver(&self) -> tokio::sync::watch::Receiver<ConnectionState> {
591        self.state_tx.subscribe()
592    }
593
594    /// Get the current connection state.
595    pub fn state(&self) -> ConnectionState {
596        self.state_tx.borrow().clone()
597    }
598
599    /// Get all active stream IDs.
600    pub async fn active_stream_ids(&self) -> Vec<u32> {
601        let streams = self.resumable_streams.read().await;
602        streams.keys().copied().collect()
603    }
604
605    /// Stop all active streams gracefully.
606    ///
607    /// Sends StopStream for each active stream and clears the stream map.
608    #[tracing::instrument(skip_all)]
609    pub async fn stop_all_streams(&self) {
610        use titan_api_types::ws::v1::StopStreamRequest;
611
612        let stream_ids = self.active_stream_ids().await;
613
614        if stream_ids.is_empty() {
615            return;
616        }
617
618        tracing::info!("Stopping {} active streams", stream_ids.len());
619
620        for stream_id in stream_ids {
621            // Send stop request (fire and forget)
622            let _ = self
623                .send_request(RequestData::StopStream(StopStreamRequest { id: stream_id }))
624                .await;
625        }
626
627        // Clear all streams
628        let mut streams = self.resumable_streams.write().await;
629        streams.clear();
630    }
631
632    /// Graceful shutdown: stop all streams and signal connection loop to exit.
633    #[tracing::instrument(skip_all)]
634    pub async fn shutdown(&self) {
635        // Stop all streams first
636        self.stop_all_streams().await;
637
638        // Update state
639        let _ = self.state_tx.send(ConnectionState::Disconnected {
640            reason: "Client shutdown".to_string(),
641        });
642
643        // The connection loop will exit when it detects the sender is closed
644        // (which happens when Connection is dropped)
645    }
646}
647
648/// Calculate exponential backoff.
649fn calculate_backoff(attempt: u32, max_delay_ms: u64) -> u64 {
650    let base_delay = INITIAL_BACKOFF_MS * 2u64.saturating_pow(attempt.saturating_sub(1));
651    base_delay.min(max_delay_ms)
652}