1use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use titan_api_codec::codec::ws::v1::ClientCodec;
10use titan_api_codec::codec::Codec;
11use titan_api_types::ws::v1::{
12 ClientRequest, RequestData, ResponseError, ResponseSuccess, ServerMessage, StreamData,
13 SwapQuoteRequest,
14};
15use tokio::net::TcpStream;
16use tokio::sync::{mpsc, oneshot, RwLock};
17use tokio_tungstenite::tungstenite::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
19
20use crate::config::TitanConfig;
21use crate::error::TitanClientError;
22use crate::state::ConnectionState;
23
24type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
25type ResponseResult = Result<ResponseSuccess, ResponseError>;
26type PendingRequestsMap = Arc<RwLock<HashMap<u32, oneshot::Sender<ResponseResult>>>>;
27
28const INITIAL_BACKOFF_MS: u64 = 100;
30
31#[derive(Clone)]
33pub struct ResumableStream {
34 pub request: SwapQuoteRequest,
36 pub sender: mpsc::Sender<StreamData>,
38}
39
40type ResumableStreamsMap = Arc<RwLock<HashMap<u32, ResumableStream>>>;
41
42pub struct PendingRequest {
44 pub request: ClientRequest,
45 pub response_tx: oneshot::Sender<ResponseResult>,
46}
47
48pub struct Connection {
50 #[allow(dead_code)]
51 config: TitanConfig,
52 request_id: AtomicU32,
53 sender: mpsc::Sender<PendingRequest>,
54 state_tx: tokio::sync::watch::Sender<ConnectionState>,
55 #[allow(dead_code)]
56 pending_requests: PendingRequestsMap,
57 resumable_streams: ResumableStreamsMap,
58}
59
60impl Connection {
61 #[tracing::instrument(skip_all)]
65 pub async fn connect(config: TitanConfig) -> Result<Self, TitanClientError> {
66 let (state_tx, _state_rx) = tokio::sync::watch::channel(ConnectionState::Disconnected {
67 reason: "Connecting...".to_string(),
68 });
69
70 let pending_requests: PendingRequestsMap = Arc::new(RwLock::new(HashMap::new()));
71 let resumable_streams: ResumableStreamsMap = Arc::new(RwLock::new(HashMap::new()));
72
73 let ws_stream = Self::establish_connection(&config).await?;
75
76 let (sender, receiver) = mpsc::channel::<PendingRequest>(32);
78
79 let pending_clone = pending_requests.clone();
81 let streams_clone = resumable_streams.clone();
82 let state_tx_clone = state_tx.clone();
83 let config_clone = config.clone();
84
85 tokio::spawn(Self::run_connection_loop_with_reconnect(
86 ws_stream,
87 receiver,
88 pending_clone,
89 streams_clone,
90 state_tx_clone,
91 config_clone,
92 ));
93
94 state_tx.send_replace(ConnectionState::Connected);
95
96 Ok(Self {
97 config,
98 request_id: AtomicU32::new(1),
99 sender,
100 state_tx,
101 pending_requests,
102 resumable_streams,
103 })
104 }
105
106 async fn establish_connection(config: &TitanConfig) -> Result<WsStream, TitanClientError> {
108 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
109
110 let url = if config.url.contains("/ws") || config.url.ends_with('/') {
113 format!("{}?auth={}", config.url, config.token)
115 } else {
116 format!("{}/?auth={}", config.url, config.token)
118 };
119
120 let mut request = url.into_client_request().map_err(|e| {
121 TitanClientError::Unexpected(anyhow::anyhow!("Failed to build request: {}", e))
122 })?;
123
124 request.headers_mut().insert(
126 "Sec-WebSocket-Protocol",
127 titan_api_types::ws::v1::WEBSOCKET_SUBPROTO_BASE
128 .parse()
129 .unwrap(),
130 );
131
132 let (ws_stream, _response) = tokio_tungstenite::connect_async(request)
134 .await
135 .map_err(TitanClientError::WebSocket)?;
136
137 Ok(ws_stream)
138 }
139
140 async fn run_connection_loop_with_reconnect(
142 initial_ws_stream: WsStream,
143 mut request_rx: mpsc::Receiver<PendingRequest>,
144 pending_requests: PendingRequestsMap,
145 resumable_streams: ResumableStreamsMap,
146 state_tx: tokio::sync::watch::Sender<ConnectionState>,
147 config: TitanConfig,
148 ) {
149 let mut ws_stream = initial_ws_stream;
150 let mut reconnect_attempt: u32 = 0;
151 let mut request_id_counter: u32 = 1;
152
153 loop {
154 let disconnect_reason = Self::run_single_connection(
156 &mut ws_stream,
157 &mut request_rx,
158 &pending_requests,
159 &resumable_streams,
160 &state_tx,
161 &mut request_id_counter,
162 )
163 .await;
164
165 if request_rx.is_closed() {
167 tracing::info!("Request channel closed, shutting down connection");
168 break;
169 }
170
171 reconnect_attempt += 1;
173
174 if let Some(max) = config.max_reconnect_attempts {
176 if reconnect_attempt > max {
177 tracing::error!("Max reconnect attempts ({}) reached, giving up", max);
178 let _ = state_tx.send(ConnectionState::Disconnected {
179 reason: format!(
180 "Max reconnect attempts reached. Last error: {}",
181 disconnect_reason
182 ),
183 });
184 break;
185 }
186 }
187
188 let backoff_ms = calculate_backoff(reconnect_attempt, config.max_reconnect_delay_ms);
190
191 tracing::info!(
192 attempt = reconnect_attempt,
193 backoff_ms,
194 "Reconnecting after disconnection: {}",
195 disconnect_reason
196 );
197
198 let _ = state_tx.send(ConnectionState::Reconnecting {
199 attempt: reconnect_attempt,
200 });
201
202 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
204
205 match Self::establish_connection(&config).await {
207 Ok(new_stream) => {
208 ws_stream = new_stream;
209 reconnect_attempt = 0;
210 let _ = state_tx.send(ConnectionState::Connected);
211 tracing::info!("Reconnected successfully");
212
213 Self::resume_streams(
215 &mut ws_stream,
216 &resumable_streams,
217 &mut request_id_counter,
218 )
219 .await;
220 }
221 Err(e) => {
222 tracing::warn!("Reconnection failed: {}", e);
223 continue;
224 }
225 }
226 }
227
228 Self::cleanup_pending_requests(&pending_requests).await;
230 }
231
232 async fn resume_streams(
234 ws_stream: &mut WsStream,
235 resumable_streams: &ResumableStreamsMap,
236 request_id_counter: &mut u32,
237 ) {
238 let streams_to_resume: Vec<(u32, ResumableStream)> = {
239 let streams = resumable_streams.read().await;
240 streams.iter().map(|(k, v)| (*k, v.clone())).collect()
241 };
242
243 if streams_to_resume.is_empty() {
244 return;
245 }
246
247 tracing::info!(
248 "Resuming {} streams after reconnection",
249 streams_to_resume.len()
250 );
251
252 let codec = ClientCodec::Uncompressed;
253 let mut encoder = codec.encoder();
254 let mut decoder = codec.decoder();
255
256 for (old_stream_id, resumable) in streams_to_resume {
257 let request_id = *request_id_counter;
258 *request_id_counter += 1;
259
260 let request = ClientRequest {
261 id: request_id,
262 data: RequestData::NewSwapQuoteStream(resumable.request.clone()),
263 };
264
265 let encoded = match encoder.encode_mut(&request) {
267 Ok(data) => data.to_vec(),
268 Err(e) => {
269 tracing::error!("Failed to encode stream resume request: {}", e);
270 continue;
271 }
272 };
273
274 if let Err(e) = ws_stream.send(Message::Binary(encoded.into())).await {
275 tracing::error!("Failed to send stream resume request: {}", e);
276 continue;
277 }
278
279 match ws_stream.next().await {
281 Some(Ok(Message::Binary(data))) => {
282 match decoder.decode_mut(data) {
283 Ok(ServerMessage::Response(response)) => {
284 if let Some(stream_info) = response.stream {
285 let new_stream_id = stream_info.id;
286
287 let mut streams = resumable_streams.write().await;
289 if let Some(stream) = streams.remove(&old_stream_id) {
290 streams.insert(new_stream_id, stream);
291 tracing::info!(
292 old_id = old_stream_id,
293 new_id = new_stream_id,
294 "Stream resumed with new ID"
295 );
296 }
297 }
298 }
299 Ok(ServerMessage::Error(error)) => {
300 tracing::error!(
301 "Failed to resume stream {}: {}",
302 old_stream_id,
303 error.message
304 );
305 let mut streams = resumable_streams.write().await;
307 streams.remove(&old_stream_id);
308 }
309 Ok(_) => {
310 tracing::warn!("Unexpected response type during stream resumption");
311 }
312 Err(e) => {
313 tracing::error!("Failed to decode stream resume response: {}", e);
314 }
315 }
316 }
317 Some(Ok(_)) => {
318 tracing::warn!("Unexpected message type during stream resumption");
319 }
320 Some(Err(e)) => {
321 tracing::error!("WebSocket error during stream resumption: {}", e);
322 break;
323 }
324 None => {
325 tracing::error!("Connection closed during stream resumption");
326 break;
327 }
328 }
329 }
330 }
331
332 async fn run_single_connection(
334 ws_stream: &mut WsStream,
335 request_rx: &mut mpsc::Receiver<PendingRequest>,
336 pending_requests: &PendingRequestsMap,
337 resumable_streams: &ResumableStreamsMap,
338 state_tx: &tokio::sync::watch::Sender<ConnectionState>,
339 request_id_counter: &mut u32,
340 ) -> String {
341 let codec = ClientCodec::Uncompressed;
342 let mut encoder = codec.encoder();
343 let mut decoder = codec.decoder();
344
345 let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
346
347 loop {
348 tokio::select! {
349 Some(pending_req) = request_rx.recv() => {
350 let request_id = pending_req.request.id;
351 *request_id_counter = request_id.max(*request_id_counter) + 1;
352
353 {
354 let mut pending_map = pending_requests.write().await;
355 pending_map.insert(request_id, pending_req.response_tx);
356 }
357
358 match encoder.encode_mut(&pending_req.request) {
359 Ok(data) => {
360 if let Err(e) = ws_sink.send(Message::Binary(data.to_vec().into())).await {
361 tracing::error!("Failed to send WebSocket message: {}", e);
362 let mut pending_map = pending_requests.write().await;
363 if let Some(tx) = pending_map.remove(&request_id) {
364 let _ = tx.send(Err(ResponseError {
365 request_id,
366 code: 0,
367 message: format!("Send failed: {}", e),
368 }));
369 }
370 }
371 }
372 Err(e) => {
373 tracing::error!("Failed to encode request: {}", e);
374 let mut pending_map = pending_requests.write().await;
375 if let Some(tx) = pending_map.remove(&request_id) {
376 let _ = tx.send(Err(ResponseError {
377 request_id,
378 code: 0,
379 message: format!("Encode failed: {}", e),
380 }));
381 }
382 }
383 }
384 }
385
386 Some(msg_result) = ws_stream_rx.next() => {
387 match msg_result {
388 Ok(Message::Binary(data)) => {
389 match decoder.decode_mut(data) {
390 Ok(server_msg) => {
391 Self::handle_server_message(
392 server_msg,
393 pending_requests,
394 resumable_streams,
395 ).await;
396 }
397 Err(e) => {
398 tracing::error!("Failed to decode server message: {}", e);
399 }
400 }
401 }
402 Ok(Message::Close(frame)) => {
403 let reason = frame
404 .map(|f| f.reason.to_string())
405 .unwrap_or_else(|| "Server closed connection".to_string());
406 tracing::warn!("WebSocket closed: {}", reason);
407 let _ = state_tx.send(ConnectionState::Disconnected {
408 reason: reason.clone(),
409 });
410 return reason;
411 }
412 Ok(Message::Ping(data)) => {
413 let _ = ws_sink.send(Message::Pong(data)).await;
414 }
415 Ok(_) => {}
416 Err(e) => {
417 let reason = format!("WebSocket error: {}", e);
418 let error_str = e.to_string();
419 if error_str.contains("Connection reset without closing handshake") {
420 tracing::info!("{}", reason);
421 } else {
422 tracing::error!("{}", reason);
423 }
424 let _ = state_tx.send(ConnectionState::Disconnected {
425 reason: reason.clone(),
426 });
427 return reason;
428 }
429 }
430 }
431
432 else => {
433 return "Channel closed".to_string();
434 }
435 }
436 }
437 }
438
439 async fn handle_server_message(
441 msg: ServerMessage,
442 pending_requests: &PendingRequestsMap,
443 resumable_streams: &ResumableStreamsMap,
444 ) {
445 match msg {
446 ServerMessage::Response(response) => {
447 let mut pending = pending_requests.write().await;
448 if let Some(tx) = pending.remove(&response.request_id) {
449 let _ = tx.send(Ok(response));
450 }
451 }
452 ServerMessage::Error(error) => {
453 let mut pending = pending_requests.write().await;
454 if let Some(tx) = pending.remove(&error.request_id) {
455 let _ = tx.send(Err(error));
456 }
457 }
458 ServerMessage::StreamData(data) => {
459 let streams = resumable_streams.read().await;
460 if let Some(stream) = streams.get(&data.id) {
461 let _ = stream.sender.send(data).await;
462 }
463 }
464 ServerMessage::StreamEnd(end) => {
465 let mut streams = resumable_streams.write().await;
466 streams.remove(&end.id);
467 }
468 ServerMessage::Other(_) => {
469 tracing::warn!("Received unknown server message type");
470 }
471 }
472 }
473
474 async fn cleanup_pending_requests(pending_requests: &PendingRequestsMap) {
476 let mut pending_map = pending_requests.write().await;
477 for (request_id, tx) in pending_map.drain() {
478 let _ = tx.send(Err(ResponseError {
479 request_id,
480 code: 0,
481 message: "Connection closed".to_string(),
482 }));
483 }
484 }
485
486 #[tracing::instrument(skip_all)]
488 pub async fn send_request(
489 &self,
490 data: RequestData,
491 ) -> Result<ResponseSuccess, TitanClientError> {
492 let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
493 let request = ClientRequest {
494 id: request_id,
495 data,
496 };
497
498 let (response_tx, response_rx) = oneshot::channel();
499
500 self.sender
501 .send(PendingRequest {
502 request,
503 response_tx,
504 })
505 .await
506 .map_err(|_| TitanClientError::Unexpected(anyhow::anyhow!("Connection closed")))?;
507
508 let response = response_rx.await.map_err(|_| {
509 TitanClientError::Unexpected(anyhow::anyhow!("Response channel closed"))
510 })?;
511
512 response.map_err(|e| TitanClientError::ServerError {
513 code: e.code,
514 message: e.message,
515 })
516 }
517
518 pub async fn register_stream(
520 &self,
521 stream_id: u32,
522 request: SwapQuoteRequest,
523 sender: mpsc::Sender<StreamData>,
524 ) {
525 let mut streams = self.resumable_streams.write().await;
526 streams.insert(stream_id, ResumableStream { request, sender });
527 }
528
529 pub async fn unregister_stream(&self, stream_id: u32) {
531 let mut streams = self.resumable_streams.write().await;
532 streams.remove(&stream_id);
533 }
534
535 pub fn state_receiver(&self) -> tokio::sync::watch::Receiver<ConnectionState> {
537 self.state_tx.subscribe()
538 }
539
540 pub fn state(&self) -> ConnectionState {
542 self.state_tx.borrow().clone()
543 }
544
545 pub async fn active_stream_ids(&self) -> Vec<u32> {
547 let streams = self.resumable_streams.read().await;
548 streams.keys().copied().collect()
549 }
550
551 #[tracing::instrument(skip_all)]
555 pub async fn stop_all_streams(&self) {
556 use titan_api_types::ws::v1::StopStreamRequest;
557
558 let stream_ids = self.active_stream_ids().await;
559
560 if stream_ids.is_empty() {
561 return;
562 }
563
564 tracing::info!("Stopping {} active streams", stream_ids.len());
565
566 for stream_id in stream_ids {
567 let _ = self
569 .send_request(RequestData::StopStream(StopStreamRequest { id: stream_id }))
570 .await;
571 }
572
573 let mut streams = self.resumable_streams.write().await;
575 streams.clear();
576 }
577
578 #[tracing::instrument(skip_all)]
580 pub async fn shutdown(&self) {
581 self.stop_all_streams().await;
583
584 let _ = self.state_tx.send(ConnectionState::Disconnected {
586 reason: "Client shutdown".to_string(),
587 });
588
589 }
592}
593
594fn calculate_backoff(attempt: u32, max_delay_ms: u64) -> u64 {
596 let base_delay = INITIAL_BACKOFF_MS * 2u64.saturating_pow(attempt.saturating_sub(1));
597 base_delay.min(max_delay_ms)
598}