1use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use titan_api_codec::codec::ws::v1::ClientCodec;
10use titan_api_codec::codec::Codec;
11use titan_api_types::ws::v1::{
12 ClientRequest, RequestData, ResponseError, ResponseSuccess, ServerMessage, StreamData,
13 SwapQuoteRequest,
14};
15use tokio::net::TcpStream;
16use tokio::sync::{mpsc, oneshot, RwLock};
17use tokio_tungstenite::tungstenite::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
19
20use crate::config::TitanConfig;
21use crate::error::TitanClientError;
22use crate::state::ConnectionState;
23
24type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
25type ResponseResult = Result<ResponseSuccess, ResponseError>;
26type PendingRequestsMap = Arc<RwLock<HashMap<u32, oneshot::Sender<ResponseResult>>>>;
27
/// Reconnect delay for the first attempt, in milliseconds; `calculate_backoff` doubles it for
/// each further attempt.
pub const INITIAL_BACKOFF_MS: u64 = 100;

/// Keepalive ping period (25 s) used when `TitanConfig::ping_interval_ms` is zero.
pub const DEFAULT_PING_INTERVAL_MS: u64 = 25 * 1_000;

/// Pong timeout of 10 s. Not referenced in this module — presumably the `TitanConfig` default;
/// confirm against the config module.
pub const DEFAULT_PONG_TIMEOUT_MS: u64 = 10 * 1_000;
36
37#[derive(Clone)]
39pub struct ResumableStream {
40 pub request: SwapQuoteRequest,
42 pub sender: mpsc::Sender<StreamData>,
44}
45
46type ResumableStreamsMap = Arc<RwLock<HashMap<u32, ResumableStream>>>;
47
48pub struct PendingRequest {
50 pub request: ClientRequest,
51 pub response_tx: oneshot::Sender<ResponseResult>,
52}
53
54pub struct Connection {
56 #[expect(dead_code)]
57 config: TitanConfig,
58 request_id: AtomicU32,
59 sender: mpsc::Sender<PendingRequest>,
60 state_tx: tokio::sync::watch::Sender<ConnectionState>,
61 #[expect(dead_code)]
62 pending_requests: PendingRequestsMap,
63 resumable_streams: ResumableStreamsMap,
64}
65
66impl Connection {
67 #[tracing::instrument(skip_all)]
71 pub async fn connect(config: TitanConfig) -> Result<Self, TitanClientError> {
72 let (state_tx, _state_rx) = tokio::sync::watch::channel(ConnectionState::Disconnected {
73 reason: "Connecting...".to_string(),
74 });
75
76 let pending_requests: PendingRequestsMap = Arc::new(RwLock::new(HashMap::new()));
77 let resumable_streams: ResumableStreamsMap = Arc::new(RwLock::new(HashMap::new()));
78
79 let ws_stream = Self::establish_connection(&config).await?;
81
82 let (sender, receiver) = mpsc::channel::<PendingRequest>(32);
84
85 let pending_clone = pending_requests.clone();
87 let streams_clone = resumable_streams.clone();
88 let state_tx_clone = state_tx.clone();
89 let config_clone = config.clone();
90
91 tokio::spawn(Self::run_connection_loop_with_reconnect(
92 ws_stream,
93 receiver,
94 pending_clone,
95 streams_clone,
96 state_tx_clone,
97 config_clone,
98 ));
99
100 state_tx.send_replace(ConnectionState::Connected);
101
102 Ok(Self {
103 config,
104 request_id: AtomicU32::new(1),
105 sender,
106 state_tx,
107 pending_requests,
108 resumable_streams,
109 })
110 }
111
112 async fn establish_connection(config: &TitanConfig) -> Result<WsStream, TitanClientError> {
114 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
115 use tokio_tungstenite::Connector;
116
117 let url = if config.url.contains("/ws") || config.url.ends_with('/') {
120 format!("{}?auth={}", config.url, config.token)
122 } else {
123 format!("{}/?auth={}", config.url, config.token)
125 };
126
127 let mut request = url.into_client_request().map_err(|e| {
128 TitanClientError::Unexpected(anyhow::anyhow!("Failed to build request: {}", e))
129 })?;
130
131 request.headers_mut().insert(
133 "Sec-WebSocket-Protocol",
134 titan_api_types::ws::v1::WEBSOCKET_SUBPROTO_BASE
135 .parse()
136 .map_err(|e| {
137 TitanClientError::Unexpected(anyhow::anyhow!(
138 "Sec-WebSocket-Protocol fail: {e}"
139 ))
140 })?,
141 );
142
143 let (ws_stream, _response) = if config.danger_accept_invalid_certs {
144 let tls_config = crate::tls::build_dangerous_tls_config();
145 let connector = Connector::Rustls(Arc::new(tls_config));
146 tokio_tungstenite::connect_async_tls_with_config(request, None, false, Some(connector))
147 .await
148 .map_err(TitanClientError::WebSocket)?
149 } else {
150 tokio_tungstenite::connect_async(request)
151 .await
152 .map_err(TitanClientError::WebSocket)?
153 };
154
155 Ok(ws_stream)
156 }
157
158 async fn run_connection_loop_with_reconnect(
160 initial_ws_stream: WsStream,
161 mut request_rx: mpsc::Receiver<PendingRequest>,
162 pending_requests: PendingRequestsMap,
163 resumable_streams: ResumableStreamsMap,
164 state_tx: tokio::sync::watch::Sender<ConnectionState>,
165 config: TitanConfig,
166 ) {
167 let mut ws_stream = initial_ws_stream;
168 let mut reconnect_attempt: u32 = 0;
169 let mut request_id_counter: u32 = 1;
170
171 loop {
172 let disconnect_reason = Self::run_single_connection(
174 &mut ws_stream,
175 &mut request_rx,
176 &pending_requests,
177 &resumable_streams,
178 &state_tx,
179 &mut request_id_counter,
180 &config,
181 )
182 .await;
183
184 if request_rx.is_closed() {
186 tracing::info!("Request channel closed, shutting down connection");
187 break;
188 }
189
190 reconnect_attempt += 1;
192
193 if let Some(max) = config.max_reconnect_attempts {
195 if reconnect_attempt > max {
196 tracing::error!("Max reconnect attempts ({}) reached, giving up", max);
197 let _ = state_tx.send(ConnectionState::Disconnected {
198 reason: format!(
199 "Max reconnect attempts reached. Last error: {}",
200 disconnect_reason
201 ),
202 });
203 break;
204 }
205 }
206
207 let backoff_ms = calculate_backoff(reconnect_attempt, config.max_reconnect_delay_ms);
209
210 tracing::debug!(
211 attempt = reconnect_attempt,
212 backoff_ms,
213 "Reconnecting after disconnection: {}",
214 disconnect_reason
215 );
216
217 let _ = state_tx.send(ConnectionState::Reconnecting {
218 attempt: reconnect_attempt,
219 });
220
221 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
223
224 match Self::establish_connection(&config).await {
226 Ok(new_stream) => {
227 ws_stream = new_stream;
228 reconnect_attempt = 0;
229 let _ = state_tx.send(ConnectionState::Connected);
230 tracing::debug!("Reconnected successfully");
231
232 Self::resume_streams(
234 &mut ws_stream,
235 &resumable_streams,
236 &mut request_id_counter,
237 )
238 .await;
239 }
240 Err(e) => {
241 tracing::warn!("Reconnection failed: {}", e);
242 }
243 }
244 }
245
246 Self::cleanup_pending_requests(&pending_requests).await;
248 }
249
250 async fn resume_streams(
252 ws_stream: &mut WsStream,
253 resumable_streams: &ResumableStreamsMap,
254 request_id_counter: &mut u32,
255 ) {
256 let streams_to_resume: Vec<(u32, ResumableStream)> = {
257 let streams = resumable_streams.read().await;
258 streams.iter().map(|(k, v)| (*k, v.clone())).collect()
259 };
260
261 if streams_to_resume.is_empty() {
262 return;
263 }
264
265 tracing::info!(
266 "Resuming {} streams after reconnection",
267 streams_to_resume.len()
268 );
269
270 let codec = ClientCodec::Uncompressed;
271 let mut encoder = codec.encoder();
272 let mut decoder = codec.decoder();
273
274 for (old_stream_id, resumable) in streams_to_resume {
275 let request_id = *request_id_counter;
276 *request_id_counter += 1;
277
278 let request = ClientRequest {
279 id: request_id,
280 data: RequestData::NewSwapQuoteStream(resumable.request.clone()),
281 };
282
283 let encoded = match encoder.encode_mut(&request) {
285 Ok(data) => data.to_vec(),
286 Err(e) => {
287 tracing::error!("Failed to encode stream resume request: {}", e);
288 continue;
289 }
290 };
291
292 if let Err(e) = ws_stream.send(Message::Binary(encoded.into())).await {
293 tracing::error!("Failed to send stream resume request: {}", e);
294 continue;
295 }
296
297 match ws_stream.next().await {
299 Some(Ok(Message::Binary(data))) => {
300 match decoder.decode_mut(data) {
301 Ok(ServerMessage::Response(response)) => {
302 if let Some(stream_info) = response.stream {
303 let new_stream_id = stream_info.id;
304
305 let mut streams = resumable_streams.write().await;
307 if let Some(stream) = streams.remove(&old_stream_id) {
308 streams.insert(new_stream_id, stream);
309 tracing::info!(
310 old_id = old_stream_id,
311 new_id = new_stream_id,
312 "Stream resumed with new ID"
313 );
314 }
315 }
316 }
317 Ok(ServerMessage::Error(error)) => {
318 tracing::error!(
319 "Failed to resume stream {}: {}",
320 old_stream_id,
321 error.message
322 );
323 let mut streams = resumable_streams.write().await;
325 streams.remove(&old_stream_id);
326 }
327 Ok(_) => {
328 tracing::warn!("Unexpected response type during stream resumption");
329 }
330 Err(e) => {
331 tracing::error!("Failed to decode stream resume response: {}", e);
332 }
333 }
334 }
335 Some(Ok(_)) => {
336 tracing::warn!("Unexpected message type during stream resumption");
337 }
338 Some(Err(e)) => {
339 tracing::error!("WebSocket error during stream resumption: {}", e);
340 break;
341 }
342 None => {
343 tracing::error!("Connection closed during stream resumption");
344 break;
345 }
346 }
347 }
348 }
349
350 async fn run_single_connection(
352 ws_stream: &mut WsStream,
353 request_rx: &mut mpsc::Receiver<PendingRequest>,
354 pending_requests: &PendingRequestsMap,
355 resumable_streams: &ResumableStreamsMap,
356 state_tx: &tokio::sync::watch::Sender<ConnectionState>,
357 request_id_counter: &mut u32,
358 config: &TitanConfig,
359 ) -> String {
360 let codec = ClientCodec::Uncompressed;
361 let mut encoder = codec.encoder();
362 let mut decoder = codec.decoder();
363
364 let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
365
366 let ping_interval_ms = if config.ping_interval_ms > 0 {
367 config.ping_interval_ms
368 } else {
369 DEFAULT_PING_INTERVAL_MS
370 };
371 let mut ping_timer = tokio::time::interval(Duration::from_millis(ping_interval_ms));
372 ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
373
374 let pong_timeout = Duration::from_millis(config.pong_timeout_ms);
375 let mut last_pong = tokio::time::Instant::now();
376
377 loop {
378 tokio::select! {
379 Some(pending_req) = request_rx.recv() => {
380 let request_id = pending_req.request.id;
381 *request_id_counter = request_id.max(*request_id_counter) + 1;
382
383 {
384 let mut pending_map = pending_requests.write().await;
385 pending_map.insert(request_id, pending_req.response_tx);
386 }
387
388 match encoder.encode_mut(&pending_req.request) {
389 Ok(data) => {
390 if let Err(e) = ws_sink.send(Message::Binary(data.to_vec().into())).await {
391 tracing::error!("Failed to send WebSocket message: {e}");
392 let mut pending_map = pending_requests.write().await;
393 if let Some(tx) = pending_map.remove(&request_id) {
394 let _ = tx.send(Err(ResponseError {
395 request_id,
396 code: 0,
397 message: format!("Send failed: {e}"),
398 }));
399 }
400 }
401 }
402 Err(e) => {
403 tracing::error!("Failed to encode request: {e}");
404 let mut pending_map = pending_requests.write().await;
405 if let Some(tx) = pending_map.remove(&request_id) {
406 let _ = tx.send(Err(ResponseError {
407 request_id,
408 code: 0,
409 message: format!("Encode failed: {e}"),
410 }));
411 }
412 }
413 }
414 }
415
416 Some(msg_result) = ws_stream_rx.next() => {
417 match msg_result {
418 Ok(Message::Binary(data)) => {
419 match decoder.decode_mut(data) {
420 Ok(server_msg) => {
421 Self::handle_server_message(
422 server_msg,
423 pending_requests,
424 resumable_streams,
425 ).await;
426 }
427 Err(e) => {
428 tracing::error!("Failed to decode server message: {e}");
429 }
430 }
431 }
432 Ok(Message::Close(frame)) => {
433 let reason = frame.map_or_else(|| "Server closed connection".to_string(), |f| f.reason.to_string());
434 tracing::warn!("WebSocket closed: {reason}");
435 let _ = state_tx.send(ConnectionState::Disconnected {
436 reason: reason.clone(),
437 });
438 return reason;
439 }
440 Ok(Message::Ping(data)) => {
441 let _ = ws_sink.send(Message::Pong(data)).await;
442 }
443 Ok(Message::Pong(_)) => {
444 last_pong = tokio::time::Instant::now();
445 tracing::trace!("Received pong from server");
446 }
447 Ok(_) => {}
448 Err(e) => {
449 let reason = format!("WebSocket error: {e}");
450 let error_str = e.to_string();
451 if error_str.contains("Connection reset without closing handshake") {
452 tracing::debug!("{reason}");
453 } else {
454 tracing::error!("{reason}");
455 }
456 let _ = state_tx.send(ConnectionState::Disconnected {
457 reason: reason.clone(),
458 });
459 return reason;
460 }
461 }
462 }
463
464 _ = ping_timer.tick() => {
465 if config.pong_timeout_ms > 0 && last_pong.elapsed() > pong_timeout {
466 let reason = "Pong timeout".to_string();
467 let timeout_ms = config.pong_timeout_ms;
468 tracing::debug!("No pong received within {timeout_ms}ms, triggering reconnect");
469 let _ = state_tx.send(ConnectionState::Disconnected {
470 reason: reason.clone(),
471 });
472 return reason;
473 }
474
475 if let Err(e) = ws_sink.send(Message::Ping(vec![].into())).await {
476 let reason = format!("Failed to send ping: {e}");
477 tracing::warn!("{reason}");
478 let _ = state_tx.send(ConnectionState::Disconnected {
479 reason: reason.clone(),
480 });
481 return reason;
482 }
483 tracing::trace!("Sent keepalive ping");
484 }
485
486 else => {
487 return "Channel closed".to_string();
488 }
489 }
490 }
491 }
492
493 async fn handle_server_message(
495 msg: ServerMessage,
496 pending_requests: &PendingRequestsMap,
497 resumable_streams: &ResumableStreamsMap,
498 ) {
499 match msg {
500 ServerMessage::Response(response) => {
501 let mut pending = pending_requests.write().await;
502 if let Some(tx) = pending.remove(&response.request_id) {
503 let _ = tx.send(Ok(response));
504 }
505 }
506 ServerMessage::Error(error) => {
507 let mut pending = pending_requests.write().await;
508 if let Some(tx) = pending.remove(&error.request_id) {
509 let _ = tx.send(Err(error));
510 }
511 }
512 ServerMessage::StreamData(data) => {
513 let streams = resumable_streams.read().await;
514 if let Some(stream) = streams.get(&data.id) {
515 let _ = stream.sender.send(data).await;
516 }
517 }
518 ServerMessage::StreamEnd(end) => {
519 let mut streams = resumable_streams.write().await;
520 streams.remove(&end.id);
521 }
522 ServerMessage::Other(_) => {
523 tracing::warn!("Received unknown server message type");
524 }
525 }
526 }
527
528 async fn cleanup_pending_requests(pending_requests: &PendingRequestsMap) {
530 let mut pending_map = pending_requests.write().await;
531 for (request_id, tx) in pending_map.drain() {
532 let _ = tx.send(Err(ResponseError {
533 request_id,
534 code: 0,
535 message: "Connection closed".to_string(),
536 }));
537 }
538 }
539
540 #[tracing::instrument(skip_all)]
542 pub async fn send_request(
543 &self,
544 data: RequestData,
545 ) -> Result<ResponseSuccess, TitanClientError> {
546 let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
547 let request = ClientRequest {
548 id: request_id,
549 data,
550 };
551
552 let (response_tx, response_rx) = oneshot::channel();
553
554 self.sender
555 .send(PendingRequest {
556 request,
557 response_tx,
558 })
559 .await
560 .map_err(|_| TitanClientError::Unexpected(anyhow::anyhow!("Connection closed")))?;
561
562 let response = response_rx.await.map_err(|_| {
563 TitanClientError::Unexpected(anyhow::anyhow!("Response channel closed"))
564 })?;
565
566 response.map_err(|e| TitanClientError::ServerError {
567 code: e.code,
568 message: e.message,
569 })
570 }
571
572 pub async fn register_stream(
574 &self,
575 stream_id: u32,
576 request: SwapQuoteRequest,
577 sender: mpsc::Sender<StreamData>,
578 ) {
579 let mut streams = self.resumable_streams.write().await;
580 streams.insert(stream_id, ResumableStream { request, sender });
581 }
582
583 pub async fn unregister_stream(&self, stream_id: u32) {
585 let mut streams = self.resumable_streams.write().await;
586 streams.remove(&stream_id);
587 }
588
589 pub fn state_receiver(&self) -> tokio::sync::watch::Receiver<ConnectionState> {
591 self.state_tx.subscribe()
592 }
593
594 pub fn state(&self) -> ConnectionState {
596 self.state_tx.borrow().clone()
597 }
598
599 pub async fn active_stream_ids(&self) -> Vec<u32> {
601 let streams = self.resumable_streams.read().await;
602 streams.keys().copied().collect()
603 }
604
605 #[tracing::instrument(skip_all)]
609 pub async fn stop_all_streams(&self) {
610 use titan_api_types::ws::v1::StopStreamRequest;
611
612 let stream_ids = self.active_stream_ids().await;
613
614 if stream_ids.is_empty() {
615 return;
616 }
617
618 tracing::info!("Stopping {} active streams", stream_ids.len());
619
620 for stream_id in stream_ids {
621 let _ = self
623 .send_request(RequestData::StopStream(StopStreamRequest { id: stream_id }))
624 .await;
625 }
626
627 let mut streams = self.resumable_streams.write().await;
629 streams.clear();
630 }
631
632 #[tracing::instrument(skip_all)]
634 pub async fn shutdown(&self) {
635 self.stop_all_streams().await;
637
638 let _ = self.state_tx.send(ConnectionState::Disconnected {
640 reason: "Client shutdown".to_string(),
641 });
642
643 }
646}
647
648fn calculate_backoff(attempt: u32, max_delay_ms: u64) -> u64 {
650 let base_delay = INITIAL_BACKOFF_MS * 2u64.saturating_pow(attempt.saturating_sub(1));
651 base_delay.min(max_delay_ms)
652}