1use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use titan_api_codec::codec::ws::v1::ClientCodec;
10use titan_api_codec::codec::Codec;
11use titan_api_types::ws::v1::{
12 ClientRequest, RequestData, ResponseError, ResponseSuccess, ServerMessage, StreamData,
13 SwapQuoteRequest,
14};
15use tokio::net::TcpStream;
16use tokio::sync::{mpsc, oneshot, RwLock};
17use tokio_tungstenite::tungstenite::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
19
20use crate::config::TitanConfig;
21use crate::error::TitanClientError;
22use crate::state::ConnectionState;
23
24type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
25type ResponseResult = Result<ResponseSuccess, ResponseError>;
26type PendingRequestsMap = Arc<RwLock<HashMap<u32, oneshot::Sender<ResponseResult>>>>;
27
28const INITIAL_BACKOFF_MS: u64 = 100;
30
31#[derive(Clone)]
33pub struct ResumableStream {
34 pub request: SwapQuoteRequest,
36 pub sender: mpsc::Sender<StreamData>,
38}
39
/// Registered streams keyed by their current server-assigned stream ID.
///
/// Shared between the `Connection` handle and its background task; the task
/// re-keys entries when streams are re-opened after a reconnect.
type ResumableStreamsMap = Arc<RwLock<HashMap<u32, ResumableStream>>>;
41
42pub struct PendingRequest {
44 pub request: ClientRequest,
45 pub response_tx: oneshot::Sender<ResponseResult>,
46}
47
/// Handle to a Titan WebSocket connection.
///
/// Requests are queued on `sender` and processed by a background task
/// spawned in [`Connection::connect`], which owns the socket itself.
pub struct Connection {
    // Stored at construction but not read by the methods shown here, hence the allow.
    #[allow(dead_code)]
    config: TitanConfig,
    // Source of request IDs handed out by `send_request`.
    request_id: AtomicU32,
    // Queue of outgoing requests consumed by the background task; dropping it
    // closes the channel the task watches for shutdown.
    sender: mpsc::Sender<PendingRequest>,
    // Publishes connection-state changes; `state()` reads the latest value.
    state_tx: tokio::sync::watch::Sender<ConnectionState>,
    // Shared with the background task, which completes entries as responses
    // arrive; not read through this handle directly, hence the allow.
    #[allow(dead_code)]
    pending_requests: PendingRequestsMap,
    // Registry of subscribed streams, shared with the background task.
    resumable_streams: ResumableStreamsMap,
}
59
60impl Connection {
61 #[tracing::instrument(skip_all)]
65 pub async fn connect(config: TitanConfig) -> Result<Self, TitanClientError> {
66 let (state_tx, _state_rx) = tokio::sync::watch::channel(ConnectionState::Disconnected {
67 reason: "Connecting...".to_string(),
68 });
69
70 let pending_requests: PendingRequestsMap = Arc::new(RwLock::new(HashMap::new()));
71 let resumable_streams: ResumableStreamsMap = Arc::new(RwLock::new(HashMap::new()));
72
73 let ws_stream = Self::establish_connection(&config).await?;
75
76 let (sender, receiver) = mpsc::channel::<PendingRequest>(32);
78
79 let pending_clone = pending_requests.clone();
81 let streams_clone = resumable_streams.clone();
82 let state_tx_clone = state_tx.clone();
83 let config_clone = config.clone();
84
85 tokio::spawn(Self::run_connection_loop_with_reconnect(
86 ws_stream,
87 receiver,
88 pending_clone,
89 streams_clone,
90 state_tx_clone,
91 config_clone,
92 ));
93
94 state_tx.send_replace(ConnectionState::Connected);
95
96 Ok(Self {
97 config,
98 request_id: AtomicU32::new(1),
99 sender,
100 state_tx,
101 pending_requests,
102 resumable_streams,
103 })
104 }
105
106 async fn establish_connection(config: &TitanConfig) -> Result<WsStream, TitanClientError> {
108 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
109
110 let url = if config.url.contains("/ws") || config.url.ends_with('/') {
113 format!("{}?auth={}", config.url, config.token)
115 } else {
116 format!("{}/?auth={}", config.url, config.token)
118 };
119
120 let mut request = url.into_client_request().map_err(|e| {
121 TitanClientError::Unexpected(anyhow::anyhow!("Failed to build request: {}", e))
122 })?;
123
124 request.headers_mut().insert(
126 "Sec-WebSocket-Protocol",
127 titan_api_types::ws::v1::WEBSOCKET_SUBPROTO_BASE
128 .parse()
129 .unwrap(),
130 );
131
132 let (ws_stream, _response) = tokio_tungstenite::connect_async(request)
134 .await
135 .map_err(TitanClientError::WebSocket)?;
136
137 Ok(ws_stream)
138 }
139
140 async fn run_connection_loop_with_reconnect(
142 initial_ws_stream: WsStream,
143 mut request_rx: mpsc::Receiver<PendingRequest>,
144 pending_requests: PendingRequestsMap,
145 resumable_streams: ResumableStreamsMap,
146 state_tx: tokio::sync::watch::Sender<ConnectionState>,
147 config: TitanConfig,
148 ) {
149 let mut ws_stream = initial_ws_stream;
150 let mut reconnect_attempt: u32 = 0;
151 let mut request_id_counter: u32 = 1;
152
153 loop {
154 let disconnect_reason = Self::run_single_connection(
156 &mut ws_stream,
157 &mut request_rx,
158 &pending_requests,
159 &resumable_streams,
160 &state_tx,
161 &mut request_id_counter,
162 )
163 .await;
164
165 if request_rx.is_closed() {
167 tracing::info!("Request channel closed, shutting down connection");
168 break;
169 }
170
171 reconnect_attempt += 1;
173
174 if let Some(max) = config.max_reconnect_attempts {
176 if reconnect_attempt > max {
177 tracing::error!("Max reconnect attempts ({}) reached, giving up", max);
178 let _ = state_tx.send(ConnectionState::Disconnected {
179 reason: format!(
180 "Max reconnect attempts reached. Last error: {}",
181 disconnect_reason
182 ),
183 });
184 break;
185 }
186 }
187
188 let backoff_ms = calculate_backoff(reconnect_attempt, config.max_reconnect_delay_ms);
190
191 tracing::info!(
192 attempt = reconnect_attempt,
193 backoff_ms,
194 "Reconnecting after disconnection: {}",
195 disconnect_reason
196 );
197
198 let _ = state_tx.send(ConnectionState::Reconnecting {
199 attempt: reconnect_attempt,
200 });
201
202 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
204
205 match Self::establish_connection(&config).await {
207 Ok(new_stream) => {
208 ws_stream = new_stream;
209 reconnect_attempt = 0;
210 let _ = state_tx.send(ConnectionState::Connected);
211 tracing::info!("Reconnected successfully");
212
213 Self::resume_streams(
215 &mut ws_stream,
216 &resumable_streams,
217 &mut request_id_counter,
218 )
219 .await;
220 }
221 Err(e) => {
222 tracing::warn!("Reconnection failed: {}", e);
223 continue;
224 }
225 }
226 }
227
228 Self::cleanup_pending_requests(&pending_requests).await;
230 }
231
232 async fn resume_streams(
234 ws_stream: &mut WsStream,
235 resumable_streams: &ResumableStreamsMap,
236 request_id_counter: &mut u32,
237 ) {
238 let streams_to_resume: Vec<(u32, ResumableStream)> = {
239 let streams = resumable_streams.read().await;
240 streams.iter().map(|(k, v)| (*k, v.clone())).collect()
241 };
242
243 if streams_to_resume.is_empty() {
244 return;
245 }
246
247 tracing::info!(
248 "Resuming {} streams after reconnection",
249 streams_to_resume.len()
250 );
251
252 let codec = ClientCodec::Uncompressed;
253 let mut encoder = codec.encoder();
254 let mut decoder = codec.decoder();
255
256 for (old_stream_id, resumable) in streams_to_resume {
257 let request_id = *request_id_counter;
258 *request_id_counter += 1;
259
260 let request = ClientRequest {
261 id: request_id,
262 data: RequestData::NewSwapQuoteStream(resumable.request.clone()),
263 };
264
265 let encoded = match encoder.encode_mut(&request) {
267 Ok(data) => data.to_vec(),
268 Err(e) => {
269 tracing::error!("Failed to encode stream resume request: {}", e);
270 continue;
271 }
272 };
273
274 if let Err(e) = ws_stream.send(Message::Binary(encoded.into())).await {
275 tracing::error!("Failed to send stream resume request: {}", e);
276 continue;
277 }
278
279 match ws_stream.next().await {
281 Some(Ok(Message::Binary(data))) => {
282 match decoder.decode_mut(data) {
283 Ok(ServerMessage::Response(response)) => {
284 if let Some(stream_info) = response.stream {
285 let new_stream_id = stream_info.id;
286
287 let mut streams = resumable_streams.write().await;
289 if let Some(stream) = streams.remove(&old_stream_id) {
290 streams.insert(new_stream_id, stream);
291 tracing::info!(
292 old_id = old_stream_id,
293 new_id = new_stream_id,
294 "Stream resumed with new ID"
295 );
296 }
297 }
298 }
299 Ok(ServerMessage::Error(error)) => {
300 tracing::error!(
301 "Failed to resume stream {}: {}",
302 old_stream_id,
303 error.message
304 );
305 let mut streams = resumable_streams.write().await;
307 streams.remove(&old_stream_id);
308 }
309 Ok(_) => {
310 tracing::warn!("Unexpected response type during stream resumption");
311 }
312 Err(e) => {
313 tracing::error!("Failed to decode stream resume response: {}", e);
314 }
315 }
316 }
317 Some(Ok(_)) => {
318 tracing::warn!("Unexpected message type during stream resumption");
319 }
320 Some(Err(e)) => {
321 tracing::error!("WebSocket error during stream resumption: {}", e);
322 break;
323 }
324 None => {
325 tracing::error!("Connection closed during stream resumption");
326 break;
327 }
328 }
329 }
330 }
331
332 async fn run_single_connection(
334 ws_stream: &mut WsStream,
335 request_rx: &mut mpsc::Receiver<PendingRequest>,
336 pending_requests: &PendingRequestsMap,
337 resumable_streams: &ResumableStreamsMap,
338 state_tx: &tokio::sync::watch::Sender<ConnectionState>,
339 request_id_counter: &mut u32,
340 ) -> String {
341 let codec = ClientCodec::Uncompressed;
342 let mut encoder = codec.encoder();
343 let mut decoder = codec.decoder();
344
345 let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
346
347 loop {
348 tokio::select! {
349 Some(pending_req) = request_rx.recv() => {
350 let request_id = pending_req.request.id;
351 *request_id_counter = request_id.max(*request_id_counter) + 1;
352
353 {
354 let mut pending_map = pending_requests.write().await;
355 pending_map.insert(request_id, pending_req.response_tx);
356 }
357
358 match encoder.encode_mut(&pending_req.request) {
359 Ok(data) => {
360 if let Err(e) = ws_sink.send(Message::Binary(data.to_vec().into())).await {
361 tracing::error!("Failed to send WebSocket message: {}", e);
362 let mut pending_map = pending_requests.write().await;
363 if let Some(tx) = pending_map.remove(&request_id) {
364 let _ = tx.send(Err(ResponseError {
365 request_id,
366 code: 0,
367 message: format!("Send failed: {}", e),
368 }));
369 }
370 }
371 }
372 Err(e) => {
373 tracing::error!("Failed to encode request: {}", e);
374 let mut pending_map = pending_requests.write().await;
375 if let Some(tx) = pending_map.remove(&request_id) {
376 let _ = tx.send(Err(ResponseError {
377 request_id,
378 code: 0,
379 message: format!("Encode failed: {}", e),
380 }));
381 }
382 }
383 }
384 }
385
386 Some(msg_result) = ws_stream_rx.next() => {
387 match msg_result {
388 Ok(Message::Binary(data)) => {
389 match decoder.decode_mut(data) {
390 Ok(server_msg) => {
391 Self::handle_server_message(
392 server_msg,
393 pending_requests,
394 resumable_streams,
395 ).await;
396 }
397 Err(e) => {
398 tracing::error!("Failed to decode server message: {}", e);
399 }
400 }
401 }
402 Ok(Message::Close(frame)) => {
403 let reason = frame
404 .map(|f| f.reason.to_string())
405 .unwrap_or_else(|| "Server closed connection".to_string());
406 tracing::warn!("WebSocket closed: {}", reason);
407 let _ = state_tx.send(ConnectionState::Disconnected {
408 reason: reason.clone(),
409 });
410 return reason;
411 }
412 Ok(Message::Ping(data)) => {
413 let _ = ws_sink.send(Message::Pong(data)).await;
414 }
415 Ok(_) => {}
416 Err(e) => {
417 let reason = format!("WebSocket error: {}", e);
418 tracing::error!("{}", reason);
419 let _ = state_tx.send(ConnectionState::Disconnected {
420 reason: reason.clone(),
421 });
422 return reason;
423 }
424 }
425 }
426
427 else => {
428 return "Channel closed".to_string();
429 }
430 }
431 }
432 }
433
434 async fn handle_server_message(
436 msg: ServerMessage,
437 pending_requests: &PendingRequestsMap,
438 resumable_streams: &ResumableStreamsMap,
439 ) {
440 match msg {
441 ServerMessage::Response(response) => {
442 let mut pending = pending_requests.write().await;
443 if let Some(tx) = pending.remove(&response.request_id) {
444 let _ = tx.send(Ok(response));
445 }
446 }
447 ServerMessage::Error(error) => {
448 let mut pending = pending_requests.write().await;
449 if let Some(tx) = pending.remove(&error.request_id) {
450 let _ = tx.send(Err(error));
451 }
452 }
453 ServerMessage::StreamData(data) => {
454 let streams = resumable_streams.read().await;
455 if let Some(stream) = streams.get(&data.id) {
456 let _ = stream.sender.send(data).await;
457 }
458 }
459 ServerMessage::StreamEnd(end) => {
460 let mut streams = resumable_streams.write().await;
461 streams.remove(&end.id);
462 }
463 ServerMessage::Other(_) => {
464 tracing::warn!("Received unknown server message type");
465 }
466 }
467 }
468
469 async fn cleanup_pending_requests(pending_requests: &PendingRequestsMap) {
471 let mut pending_map = pending_requests.write().await;
472 for (request_id, tx) in pending_map.drain() {
473 let _ = tx.send(Err(ResponseError {
474 request_id,
475 code: 0,
476 message: "Connection closed".to_string(),
477 }));
478 }
479 }
480
481 #[tracing::instrument(skip_all)]
483 pub async fn send_request(
484 &self,
485 data: RequestData,
486 ) -> Result<ResponseSuccess, TitanClientError> {
487 let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
488 let request = ClientRequest {
489 id: request_id,
490 data,
491 };
492
493 let (response_tx, response_rx) = oneshot::channel();
494
495 self.sender
496 .send(PendingRequest {
497 request,
498 response_tx,
499 })
500 .await
501 .map_err(|_| TitanClientError::Unexpected(anyhow::anyhow!("Connection closed")))?;
502
503 let response = response_rx.await.map_err(|_| {
504 TitanClientError::Unexpected(anyhow::anyhow!("Response channel closed"))
505 })?;
506
507 response.map_err(|e| TitanClientError::ServerError {
508 code: e.code,
509 message: e.message,
510 })
511 }
512
513 pub async fn register_stream(
515 &self,
516 stream_id: u32,
517 request: SwapQuoteRequest,
518 sender: mpsc::Sender<StreamData>,
519 ) {
520 let mut streams = self.resumable_streams.write().await;
521 streams.insert(stream_id, ResumableStream { request, sender });
522 }
523
524 pub async fn unregister_stream(&self, stream_id: u32) {
526 let mut streams = self.resumable_streams.write().await;
527 streams.remove(&stream_id);
528 }
529
530 pub fn state_receiver(&self) -> tokio::sync::watch::Receiver<ConnectionState> {
532 self.state_tx.subscribe()
533 }
534
535 pub fn state(&self) -> ConnectionState {
537 self.state_tx.borrow().clone()
538 }
539
540 pub async fn active_stream_ids(&self) -> Vec<u32> {
542 let streams = self.resumable_streams.read().await;
543 streams.keys().copied().collect()
544 }
545
546 #[tracing::instrument(skip_all)]
550 pub async fn stop_all_streams(&self) {
551 use titan_api_types::ws::v1::StopStreamRequest;
552
553 let stream_ids = self.active_stream_ids().await;
554
555 if stream_ids.is_empty() {
556 return;
557 }
558
559 tracing::info!("Stopping {} active streams", stream_ids.len());
560
561 for stream_id in stream_ids {
562 let _ = self
564 .send_request(RequestData::StopStream(StopStreamRequest { id: stream_id }))
565 .await;
566 }
567
568 let mut streams = self.resumable_streams.write().await;
570 streams.clear();
571 }
572
573 #[tracing::instrument(skip_all)]
575 pub async fn shutdown(&self) {
576 self.stop_all_streams().await;
578
579 let _ = self.state_tx.send(ConnectionState::Disconnected {
581 reason: "Client shutdown".to_string(),
582 });
583
584 }
587}
588
589fn calculate_backoff(attempt: u32, max_delay_ms: u64) -> u64 {
591 let base_delay = INITIAL_BACKOFF_MS * 2u64.saturating_pow(attempt.saturating_sub(1));
592 base_delay.min(max_delay_ms)
593}