Skip to main content

titan_rust_client/
connection.rs

1//! WebSocket connection management with auto-reconnect and stream resumption.
2
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use titan_api_codec::codec::ws::v1::ClientCodec;
10use titan_api_codec::codec::Codec;
11use titan_api_types::ws::v1::{
12    ClientRequest, RequestData, ResponseSuccess, ServerMessage, StreamData, SwapQuoteRequest,
13};
14use tokio::net::TcpStream;
15use tokio::sync::{mpsc, oneshot, RwLock};
16use tokio_tungstenite::tungstenite::Message;
17use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
18use tokio_util::sync::CancellationToken;
19
20use crate::config::TitanConfig;
21use crate::error::TitanClientError;
22use crate::state::ConnectionState;
23
24type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
25type ResponseResult = Result<ResponseSuccess, TitanClientError>;
26type PendingRequestsMap = Arc<RwLock<HashMap<u32, oneshot::Sender<ResponseResult>>>>;
27type OnEndCallback = Arc<dyn Fn() + Send + Sync>;
28
/// Delay before the first reconnection attempt, in milliseconds; each later attempt doubles it.
pub const INITIAL_BACKOFF_MS: u64 = 100;

/// Keepalive ping interval in milliseconds, used when the configured `ping_interval_ms` is 0.
pub const DEFAULT_PING_INTERVAL_MS: u64 = 25 * 1_000;

/// Default pong timeout in milliseconds: if no pong arrives within this window after a ping,
/// the connection is considered dead and a reconnect is triggered.
pub const DEFAULT_PONG_TIMEOUT_MS: u64 = 10 * 1_000;
37
38/// Information needed to resume a stream after reconnection.
39#[derive(Clone)]
40pub struct ResumableStream {
41    /// The original request used to create the stream.
42    pub request: SwapQuoteRequest,
43    /// Channel to send stream data to.
44    pub sender: mpsc::Sender<StreamData>,
45    /// Called when this stream ends (server-initiated or reconnect failure) to release the slot.
46    pub on_end: Option<OnEndCallback>,
47    /// Shared atomic that tracks the current server-side stream ID (updated on reconnect remap).
48    pub effective_id: Option<Arc<AtomicU32>>,
49    /// Shared flag indicating the stream has been stopped or dropped by the client.
50    pub stopped: Arc<AtomicBool>,
51}
52
53type ResumableStreamsMap = Arc<RwLock<HashMap<u32, ResumableStream>>>;
54
55/// Internal message for sending requests through the connection
56pub struct PendingRequest {
57    pub request: ClientRequest,
58    pub response_tx: oneshot::Sender<ResponseResult>,
59}
60
61/// Manages a WebSocket connection to the Titan API with auto-reconnect.
62pub struct Connection {
63    #[expect(dead_code)]
64    config: TitanConfig,
65    request_id: AtomicU32,
66    sender: mpsc::Sender<PendingRequest>,
67    shutdown: CancellationToken,
68    state_tx: tokio::sync::watch::Sender<ConnectionState>,
69    #[expect(dead_code)]
70    pending_requests: PendingRequestsMap,
71    resumable_streams: ResumableStreamsMap,
72}
73
74struct RunSingleConnectionArgs<'a> {
75    ws_stream: &'a mut WsStream,
76    request_rx: &'a mut mpsc::Receiver<PendingRequest>,
77    pending_requests: &'a PendingRequestsMap,
78    resumable_streams: &'a ResumableStreamsMap,
79    state_tx: &'a tokio::sync::watch::Sender<ConnectionState>,
80    request_id_counter: &'a mut u32,
81    config: &'a TitanConfig,
82    shutdown: &'a CancellationToken,
83}
84
85impl Connection {
86    /// Create a new connection with the given config.
87    ///
88    /// Connects eagerly and auto-reconnects on disconnection.
89    #[tracing::instrument(skip_all)]
90    pub async fn connect(config: TitanConfig) -> Result<Self, TitanClientError> {
91        let (state_tx, _state_rx) = tokio::sync::watch::channel(ConnectionState::Disconnected {
92            reason: "Connecting...".to_string(),
93        });
94
95        let pending_requests: PendingRequestsMap = Arc::new(RwLock::new(HashMap::new()));
96        let resumable_streams: ResumableStreamsMap = Arc::new(RwLock::new(HashMap::new()));
97        let shutdown = CancellationToken::new();
98
99        // Connect to WebSocket
100        let ws_stream = Self::establish_connection(&config).await?;
101
102        // Create channel for sending requests
103        let (sender, receiver) = mpsc::channel::<PendingRequest>(32);
104
105        // Spawn background task with reconnection support
106        let pending_clone = pending_requests.clone();
107        let streams_clone = resumable_streams.clone();
108        let state_tx_clone = state_tx.clone();
109        let config_clone = config.clone();
110
111        tokio::spawn(Self::run_connection_loop_with_reconnect(
112            ws_stream,
113            receiver,
114            pending_clone,
115            streams_clone,
116            state_tx_clone,
117            config_clone,
118            shutdown.clone(),
119        ));
120
121        state_tx.send_replace(ConnectionState::Connected);
122
123        Ok(Self {
124            config,
125            request_id: AtomicU32::new(1),
126            sender,
127            shutdown,
128            state_tx,
129            pending_requests,
130            resumable_streams,
131        })
132    }
133
134    /// Establish WebSocket connection with authentication.
135    async fn establish_connection(config: &TitanConfig) -> Result<WsStream, TitanClientError> {
136        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
137        use tokio_tungstenite::Connector;
138
139        let url = if config.url.contains("/ws") || config.url.ends_with('/') {
140            format!("{}?auth={}", config.url, config.token)
141        } else {
142            format!("{}/?auth={}", config.url, config.token)
143        };
144
145        let mut request = url.into_client_request().map_err(|e| {
146            TitanClientError::Unexpected(anyhow::anyhow!("Failed to build request: {e}"))
147        })?;
148
149        request.headers_mut().insert(
150            "Sec-WebSocket-Protocol",
151            titan_api_types::ws::v1::WEBSOCKET_SUBPROTO_BASE
152                .parse()
153                .map_err(|e| {
154                    TitanClientError::Unexpected(anyhow::anyhow!(
155                        "Sec-WebSocket-Protocol fail: {e}"
156                    ))
157                })?,
158        );
159
160        let tls_config = if config.danger_accept_invalid_certs {
161            crate::tls::build_dangerous_tls_config()
162        } else {
163            crate::tls::build_default_tls_config()
164        }
165        .map_err(|e| TitanClientError::Unexpected(anyhow::anyhow!("TLS config failed: {e}")))?;
166        let connector = Connector::Rustls(Arc::new(tls_config));
167        let (ws_stream, _response) =
168            tokio_tungstenite::connect_async_tls_with_config(request, None, false, Some(connector))
169                .await
170                .map_err(TitanClientError::WebSocket)?;
171        Ok(ws_stream)
172    }
173
174    /// Connection loop with automatic reconnection and stream resumption.
175    async fn run_connection_loop_with_reconnect(
176        initial_ws_stream: WsStream,
177        mut request_rx: mpsc::Receiver<PendingRequest>,
178        pending_requests: PendingRequestsMap,
179        resumable_streams: ResumableStreamsMap,
180        state_tx: tokio::sync::watch::Sender<ConnectionState>,
181        config: TitanConfig,
182        shutdown: CancellationToken,
183    ) {
184        let mut ws_stream = initial_ws_stream;
185        let mut reconnect_attempt: u32 = 0;
186        let mut request_id_counter: u32 = 1;
187
188        loop {
189            // Run the connection loop until disconnection
190            let disconnect_reason = Self::run_single_connection(RunSingleConnectionArgs {
191                ws_stream: &mut ws_stream,
192                request_rx: &mut request_rx,
193                pending_requests: &pending_requests,
194                resumable_streams: &resumable_streams,
195                state_tx: &state_tx,
196                request_id_counter: &mut request_id_counter,
197                config: &config,
198                shutdown: &shutdown,
199            })
200            .await;
201
202            // Fail all pending requests from this connection immediately
203            Self::fail_pending_requests(&pending_requests, &disconnect_reason).await;
204
205            if shutdown.is_cancelled() {
206                break;
207            }
208
209            // Check if request channel is closed (client dropped)
210            if request_rx.is_closed() {
211                tracing::info!("Request channel closed, shutting down connection");
212                break;
213            }
214
215            // Start reconnection attempts
216            reconnect_attempt += 1;
217
218            // Check max attempts
219            if let Some(max) = config.max_reconnect_attempts {
220                if reconnect_attempt > max {
221                    tracing::error!("Max reconnect attempts ({}) reached, giving up", max);
222                    let _ = state_tx.send(ConnectionState::Disconnected {
223                        reason: format!(
224                            "Max reconnect attempts reached. Last error: {}",
225                            disconnect_reason
226                        ),
227                    });
228                    break;
229                }
230            }
231
232            // Calculate backoff delay with exponential increase
233            let backoff_ms = calculate_backoff(reconnect_attempt, config.max_reconnect_delay_ms);
234
235            tracing::debug!(
236                attempt = reconnect_attempt,
237                backoff_ms,
238                "Reconnecting after disconnection: {}",
239                disconnect_reason
240            );
241
242            let _ = state_tx.send(ConnectionState::Reconnecting {
243                attempt: reconnect_attempt,
244            });
245
246            // Wait before reconnecting
247            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
248
249            // Attempt to reconnect
250            match Self::establish_connection(&config).await {
251                Ok(new_stream) => {
252                    ws_stream = new_stream;
253                    reconnect_attempt = 0;
254                    let _ = state_tx.send(ConnectionState::Connected);
255                    tracing::debug!("Reconnected successfully");
256
257                    // Resume streams after reconnection
258                    Self::resume_streams(
259                        &mut ws_stream,
260                        &pending_requests,
261                        &resumable_streams,
262                        &mut request_id_counter,
263                    )
264                    .await;
265                }
266                Err(e) => {
267                    tracing::warn!("Reconnection failed: {}", e);
268                }
269            }
270        }
271
272        // Final cleanup
273        Self::cleanup_pending_requests(&pending_requests).await;
274        Self::cleanup_resumable_streams(&resumable_streams).await;
275    }
276
277    /// Resume all active streams after reconnection.
278    async fn resume_streams(
279        ws_stream: &mut WsStream,
280        pending_requests: &PendingRequestsMap,
281        resumable_streams: &ResumableStreamsMap,
282        request_id_counter: &mut u32,
283    ) {
284        let streams_to_resume: Vec<(u32, ResumableStream)> = {
285            let streams = resumable_streams.read().await;
286            streams.iter().map(|(k, v)| (*k, v.clone())).collect()
287        };
288
289        if streams_to_resume.is_empty() {
290            return;
291        }
292
293        tracing::info!(
294            "Resuming {} streams after reconnection",
295            streams_to_resume.len()
296        );
297
298        let codec = ClientCodec::Uncompressed;
299        let mut encoder = codec.encoder();
300        let mut decoder = codec.decoder();
301
302        for (old_stream_id, resumable) in streams_to_resume {
303            if resumable.stopped.load(Ordering::SeqCst) || resumable.sender.is_closed() {
304                let mut streams = resumable_streams.write().await;
305                if let Some(stream) = streams.remove(&old_stream_id) {
306                    if let Some(ref on_end) = stream.on_end {
307                        on_end();
308                    }
309                }
310                continue;
311            }
312
313            let request_id = *request_id_counter;
314            *request_id_counter += 1;
315
316            let request = ClientRequest {
317                id: request_id,
318                data: RequestData::NewSwapQuoteStream(resumable.request.clone()),
319            };
320
321            // Encode and send the request
322            let encoded = match encoder.encode_mut(&request) {
323                Ok(data) => data.to_vec(),
324                Err(e) => {
325                    tracing::error!("Failed to encode stream resume request: {}", e);
326                    let mut streams = resumable_streams.write().await;
327                    if let Some(stream) = streams.remove(&old_stream_id) {
328                        if let Some(ref on_end) = stream.on_end {
329                            on_end();
330                        }
331                    }
332                    continue;
333                }
334            };
335
336            if let Err(e) = ws_stream.send(Message::Binary(encoded.into())).await {
337                tracing::error!("Failed to send stream resume request: {}", e);
338                let mut streams = resumable_streams.write().await;
339                if let Some(stream) = streams.remove(&old_stream_id) {
340                    if let Some(ref on_end) = stream.on_end {
341                        on_end();
342                    }
343                }
344                continue;
345            }
346
347            // Wait for response to get new stream ID (skip interleaved frames)
348            loop {
349                match ws_stream.next().await {
350                    Some(Ok(Message::Binary(data))) => match decoder.decode_mut(data) {
351                        Ok(ServerMessage::Response(response)) => {
352                            if response.request_id != request_id {
353                                Self::handle_server_message(
354                                    ServerMessage::Response(response),
355                                    pending_requests,
356                                    resumable_streams,
357                                )
358                                .await;
359                                continue;
360                            }
361
362                            if let Some(stream_info) = response.stream {
363                                let new_stream_id = stream_info.id;
364
365                                // Update the stream mapping
366                                let mut streams = resumable_streams.write().await;
367                                if let Some(stream) = streams.remove(&old_stream_id) {
368                                    if stream.stopped.load(Ordering::SeqCst)
369                                        || stream.sender.is_closed()
370                                    {
371                                        if let Some(ref on_end) = stream.on_end {
372                                            on_end();
373                                        }
374                                    } else {
375                                        // Update the shared effective_id so QuoteStream uses the new ID
376                                        if let Some(ref effective_id) = stream.effective_id {
377                                            effective_id.store(new_stream_id, Ordering::SeqCst);
378                                        }
379                                        streams.insert(new_stream_id, stream);
380                                        tracing::info!(
381                                            old_id = old_stream_id,
382                                            new_id = new_stream_id,
383                                            "Stream resumed with new ID"
384                                        );
385                                    }
386                                }
387                            } else {
388                                tracing::error!(
389                                    "Stream resume response missing stream info for {}",
390                                    old_stream_id
391                                );
392                                let mut streams = resumable_streams.write().await;
393                                if let Some(stream) = streams.remove(&old_stream_id) {
394                                    if let Some(ref on_end) = stream.on_end {
395                                        on_end();
396                                    }
397                                }
398                            }
399                            break;
400                        }
401                        Ok(ServerMessage::Error(error)) => {
402                            if error.request_id != request_id {
403                                Self::handle_server_message(
404                                    ServerMessage::Error(error),
405                                    pending_requests,
406                                    resumable_streams,
407                                )
408                                .await;
409                                continue;
410                            }
411
412                            tracing::error!(
413                                "Failed to resume stream {}: {}",
414                                old_stream_id,
415                                error.message
416                            );
417                            // Remove the failed stream and release its slot
418                            let mut streams = resumable_streams.write().await;
419                            if let Some(stream) = streams.remove(&old_stream_id) {
420                                if let Some(ref on_end) = stream.on_end {
421                                    on_end();
422                                }
423                            }
424                            break;
425                        }
426                        Ok(other) => {
427                            Self::handle_server_message(other, pending_requests, resumable_streams)
428                                .await;
429                        }
430                        Err(e) => {
431                            tracing::error!("Failed to decode stream resume response: {}", e);
432                        }
433                    },
434                    Some(Ok(Message::Ping(data))) => {
435                        let _ = ws_stream.send(Message::Pong(data)).await;
436                    }
437                    Some(Ok(Message::Pong(_))) => {}
438                    Some(Ok(Message::Close(frame))) => {
439                        let reason = frame.map_or_else(
440                            || "Server closed connection".to_string(),
441                            |f| f.reason.to_string(),
442                        );
443                        tracing::warn!("WebSocket closed during stream resumption: {reason}");
444                        break;
445                    }
446                    Some(Ok(_)) => {}
447                    Some(Err(e)) => {
448                        tracing::error!("WebSocket error during stream resumption: {}", e);
449                        break;
450                    }
451                    None => {
452                        tracing::error!("Connection closed during stream resumption");
453                        break;
454                    }
455                }
456            }
457        }
458    }
459
460    /// Run a single connection until disconnection.
461    async fn run_single_connection(args: RunSingleConnectionArgs<'_>) -> String {
462        let RunSingleConnectionArgs {
463            ws_stream,
464            request_rx,
465            pending_requests,
466            resumable_streams,
467            state_tx,
468            request_id_counter,
469            config,
470            shutdown,
471        } = args;
472        let codec = ClientCodec::Uncompressed;
473        let mut encoder = codec.encoder();
474        let mut decoder = codec.decoder();
475
476        let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
477
478        let ping_interval_ms = if config.ping_interval_ms > 0 {
479            config.ping_interval_ms
480        } else {
481            DEFAULT_PING_INTERVAL_MS
482        };
483        let mut ping_timer = tokio::time::interval(Duration::from_millis(ping_interval_ms));
484        ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
485
486        let pong_timeout = Duration::from_millis(config.pong_timeout_ms);
487        let mut last_ping = tokio::time::Instant::now();
488        let mut awaiting_pong = false;
489
490        loop {
491            tokio::select! {
492                () = shutdown.cancelled() => {
493                    return "Client shutdown".to_string();
494                }
495                maybe_req = request_rx.recv() => {
496                    let Some(pending_req) = maybe_req else {
497                        return "Request channel closed".to_string();
498                    };
499
500                    let request_id = pending_req.request.id;
501                    *request_id_counter = request_id.max(*request_id_counter) + 1;
502
503                    {
504                        let mut pending_map = pending_requests.write().await;
505                        pending_map.insert(request_id, pending_req.response_tx);
506                    }
507
508                    match encoder.encode_mut(&pending_req.request) {
509                        Ok(data) => {
510                            if let Err(e) = ws_sink.send(Message::Binary(data.to_vec().into())).await {
511                                tracing::error!("Failed to send WebSocket message: {e}");
512                                let mut pending_map = pending_requests.write().await;
513                                if let Some(tx) = pending_map.remove(&request_id) {
514                                    let _ = tx.send(Err(TitanClientError::ConnectionClosed {
515                                        reason: format!("Send failed: {e}"),
516                                    }));
517                                }
518                            }
519                        }
520                        Err(e) => {
521                            tracing::error!("Failed to encode request: {e}");
522                            let mut pending_map = pending_requests.write().await;
523                            if let Some(tx) = pending_map.remove(&request_id) {
524                                let _ = tx.send(Err(TitanClientError::Unexpected(anyhow::anyhow!(
525                                    "Encode failed: {e}"
526                                ))));
527                            }
528                        }
529                    }
530                }
531
532                Some(msg_result) = ws_stream_rx.next() => {
533                    match msg_result {
534                        Ok(Message::Binary(data)) => {
535                            match decoder.decode_mut(data) {
536                                Ok(server_msg) => {
537                                    Self::handle_server_message(
538                                        server_msg,
539                                        pending_requests,
540                                        resumable_streams,
541                                    ).await;
542                                }
543                                Err(e) => {
544                                    tracing::error!("Failed to decode server message: {e}");
545                                }
546                            }
547                        }
548                        Ok(Message::Close(frame)) => {
549                            let reason = frame.map_or_else(|| "Server closed connection".to_string(), |f| f.reason.to_string());
550                            tracing::warn!("WebSocket closed: {reason}");
551                            let _ = state_tx.send(ConnectionState::Disconnected {
552                                reason: reason.clone(),
553                            });
554                            return reason;
555                        }
556                        Ok(Message::Ping(data)) => {
557                            let _ = ws_sink.send(Message::Pong(data)).await;
558                        }
559                        Ok(Message::Pong(_)) => {
560                            awaiting_pong = false;
561                            tracing::trace!("Received pong from server");
562                        }
563                        Ok(_) => {}
564                        Err(e) => {
565                            let reason = format!("WebSocket error: {e}");
566                            let error_str = e.to_string();
567                            if error_str.contains("Connection reset without closing handshake") {
568                                tracing::debug!("{reason}");
569                            } else {
570                                tracing::error!("{reason}");
571                            }
572                            let _ = state_tx.send(ConnectionState::Disconnected {
573                                reason: reason.clone(),
574                            });
575                            return reason;
576                        }
577                    }
578                }
579
580                _ = ping_timer.tick() => {
581                    if config.pong_timeout_ms > 0 && awaiting_pong && last_ping.elapsed() > pong_timeout {
582                        let reason = "Pong timeout".to_string();
583                        let timeout_ms = config.pong_timeout_ms;
584                        tracing::debug!("No pong received within {timeout_ms}ms, triggering reconnect");
585                        let _ = state_tx.send(ConnectionState::Disconnected {
586                            reason: reason.clone(),
587                        });
588                        return reason;
589                    }
590
591                    if let Err(e) = ws_sink.send(Message::Ping(vec![].into())).await {
592                        let reason = format!("Failed to send ping: {e}");
593                        tracing::warn!("{reason}");
594                        let _ = state_tx.send(ConnectionState::Disconnected {
595                            reason: reason.clone(),
596                        });
597                        return reason;
598                    }
599                    awaiting_pong = true;
600                    last_ping = tokio::time::Instant::now();
601                    tracing::trace!("Sent keepalive ping");
602                }
603
604                else => {
605                    return "Channel closed".to_string();
606                }
607            }
608        }
609    }
610
611    /// Handle a message received from the server.
612    async fn handle_server_message(
613        msg: ServerMessage,
614        pending_requests: &PendingRequestsMap,
615        resumable_streams: &ResumableStreamsMap,
616    ) {
617        match msg {
618            ServerMessage::Response(response) => {
619                let mut pending = pending_requests.write().await;
620                if let Some(tx) = pending.remove(&response.request_id) {
621                    let _ = tx.send(Ok(response));
622                }
623            }
624            ServerMessage::Error(error) => {
625                let mut pending = pending_requests.write().await;
626                if let Some(tx) = pending.remove(&error.request_id) {
627                    let _ = tx.send(Err(TitanClientError::ServerError {
628                        code: error.code,
629                        message: error.message,
630                    }));
631                }
632            }
633            ServerMessage::StreamData(data) => {
634                let streams = resumable_streams.read().await;
635                if let Some(stream) = streams.get(&data.id) {
636                    let _ = stream.sender.send(data).await;
637                }
638            }
639            ServerMessage::StreamEnd(end) => {
640                let mut streams = resumable_streams.write().await;
641                if let Some(stream) = streams.remove(&end.id) {
642                    if let Some(ref on_end) = stream.on_end {
643                        on_end();
644                    }
645                }
646            }
647            ServerMessage::Other(_) => {
648                tracing::warn!("Received unknown server message type");
649            }
650        }
651    }
652
653    /// Cleanup pending requests on final shutdown.
654    async fn cleanup_pending_requests(pending_requests: &PendingRequestsMap) {
655        Self::fail_pending_requests(pending_requests, "Connection closed").await;
656    }
657
658    async fn fail_pending_requests(pending_requests: &PendingRequestsMap, reason: &str) {
659        let mut pending_map = pending_requests.write().await;
660        for (_request_id, tx) in pending_map.drain() {
661            let _ = tx.send(Err(TitanClientError::ConnectionClosed {
662                reason: reason.to_string(),
663            }));
664        }
665    }
666
667    /// Drop all stream senders so `QuoteStream::recv()` returns `None` instead of hanging.
668    async fn cleanup_resumable_streams(resumable_streams: &ResumableStreamsMap) {
669        let mut streams = resumable_streams.write().await;
670        for (_id, stream) in streams.drain() {
671            if let Some(ref on_end) = stream.on_end {
672                on_end();
673            }
674        }
675    }
676
677    /// Send a request and wait for response.
678    #[tracing::instrument(skip_all)]
679    pub async fn send_request(
680        &self,
681        data: RequestData,
682    ) -> Result<ResponseSuccess, TitanClientError> {
683        if self.shutdown.is_cancelled() {
684            return Err(TitanClientError::ConnectionClosed {
685                reason: "Client shutdown".to_string(),
686            });
687        }
688
689        let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
690        let request = ClientRequest {
691            id: request_id,
692            data,
693        };
694
695        let (response_tx, response_rx) = oneshot::channel();
696
697        self.sender
698            .send(PendingRequest {
699                request,
700                response_tx,
701            })
702            .await
703            .map_err(|_| TitanClientError::ConnectionClosed {
704                reason: "Connection closed".to_string(),
705            })?;
706
707        let response = response_rx
708            .await
709            .map_err(|_| TitanClientError::ConnectionClosed {
710                reason: "Response channel closed".to_string(),
711            })?;
712
713        response
714    }
715
716    /// Register a resumable stream.
717    pub async fn register_stream(
718        &self,
719        stream_id: u32,
720        request: SwapQuoteRequest,
721        sender: mpsc::Sender<StreamData>,
722        on_end: Option<OnEndCallback>,
723        effective_id: Option<Arc<AtomicU32>>,
724        stopped: Arc<AtomicBool>,
725    ) {
726        let mut streams = self.resumable_streams.write().await;
727        streams.insert(
728            stream_id,
729            ResumableStream {
730                request,
731                sender,
732                on_end,
733                effective_id,
734                stopped,
735            },
736        );
737    }
738
739    /// Unregister a stream.
740    pub async fn unregister_stream(&self, stream_id: u32) {
741        let mut streams = self.resumable_streams.write().await;
742        streams.remove(&stream_id);
743    }
744
745    /// Get a receiver for connection state changes.
746    pub fn state_receiver(&self) -> tokio::sync::watch::Receiver<ConnectionState> {
747        self.state_tx.subscribe()
748    }
749
750    /// Get the current connection state.
751    pub fn state(&self) -> ConnectionState {
752        self.state_tx.borrow().clone()
753    }
754
755    /// Get all active stream IDs.
756    pub async fn active_stream_ids(&self) -> Vec<u32> {
757        let streams = self.resumable_streams.read().await;
758        streams.keys().copied().collect()
759    }
760
761    /// Stop all active streams gracefully.
762    ///
763    /// Sends StopStream for each active stream and clears the stream map.
764    #[tracing::instrument(skip_all)]
765    pub async fn stop_all_streams(&self) {
766        use titan_api_types::ws::v1::StopStreamRequest;
767
768        let stream_ids = self.active_stream_ids().await;
769
770        if stream_ids.is_empty() {
771            return;
772        }
773
774        tracing::info!("Stopping {} active streams", stream_ids.len());
775
776        for stream_id in stream_ids {
777            // Send stop request (fire and forget)
778            let _ = self
779                .send_request(RequestData::StopStream(StopStreamRequest { id: stream_id }))
780                .await;
781        }
782
783        // Clear all streams and call on_end for each
784        let mut streams = self.resumable_streams.write().await;
785        for (_id, stream) in streams.drain() {
786            if let Some(ref on_end) = stream.on_end {
787                on_end();
788            }
789        }
790    }
791
792    /// Graceful shutdown: stop all streams and signal connection loop to exit.
793    #[tracing::instrument(skip_all)]
794    pub async fn shutdown(&self) {
795        // Stop all streams first
796        self.stop_all_streams().await;
797
798        self.shutdown.cancel();
799
800        // Update state
801        let _ = self.state_tx.send(ConnectionState::Disconnected {
802            reason: "Client shutdown".to_string(),
803        });
804
805        // The connection loop will exit when it detects the sender is closed
806        // (which happens when Connection is dropped)
807    }
808}
809
810/// Calculate exponential backoff.
811fn calculate_backoff(attempt: u32, max_delay_ms: u64) -> u64 {
812    let base_delay = INITIAL_BACKOFF_MS * 2u64.saturating_pow(attempt.saturating_sub(1));
813    base_delay.min(max_delay_ms)
814}