titan_rust_client/
client.rs

1//! Main Titan client implementation.
2
3use std::sync::Arc;
4
5use titan_api_types::ws::v1::{
6    GetInfoRequest, GetVenuesRequest, ListProvidersRequest, ProviderInfo, RequestData,
7    ResponseData, ServerInfo, SwapPrice, SwapPriceRequest, SwapQuoteRequest, VenueInfo,
8};
9use tokio::sync::RwLock;
10
11use crate::config::TitanConfig;
12use crate::connection::Connection;
13use crate::error::TitanClientError;
14use crate::queue::StreamManager;
15use crate::state::ConnectionState;
16use crate::stream::QuoteStream;
17
18/// Default max concurrent streams if server doesn't specify.
19const DEFAULT_MAX_CONCURRENT_STREAMS: u32 = 10;
20
21/// Titan Exchange WebSocket client.
22///
23/// Thread-safe client for interacting with the Titan Exchange API.
24/// Can be shared across axum handlers via `Arc<TitanClient>`.
25pub struct TitanClient {
26    connection: Arc<RwLock<Option<Arc<Connection>>>>,
27    stream_manager: Arc<RwLock<Option<Arc<StreamManager>>>>,
28    #[allow(dead_code)]
29    config: TitanConfig,
30}
31
32impl TitanClient {
33    /// Create a new client with the given configuration.
34    ///
35    /// Connects eagerly and fetches server info to determine stream limits.
36    #[tracing::instrument(skip_all)]
37    pub async fn new(config: TitanConfig) -> Result<Self, TitanClientError> {
38        let connection = Arc::new(Connection::connect(config.clone()).await?);
39
40        // Fetch server info to get max concurrent streams
41        let max_streams = Self::fetch_max_streams(&connection).await;
42
43        let stream_manager = StreamManager::new(connection.clone(), max_streams);
44
45        Ok(Self {
46            connection: Arc::new(RwLock::new(Some(connection))),
47            stream_manager: Arc::new(RwLock::new(Some(stream_manager))),
48            config,
49        })
50    }
51
52    /// Fetch max concurrent streams from server info.
53    async fn fetch_max_streams(connection: &Arc<Connection>) -> u32 {
54        match connection
55            .send_request(RequestData::GetInfo(GetInfoRequest { dummy: None }))
56            .await
57        {
58            Ok(response) => match response.data {
59                ResponseData::GetInfo(info) => info.settings.connection.concurrent_streams,
60                _ => DEFAULT_MAX_CONCURRENT_STREAMS,
61            },
62            Err(e) => {
63                tracing::warn!("Failed to fetch server info: {}, using default limits", e);
64                DEFAULT_MAX_CONCURRENT_STREAMS
65            }
66        }
67    }
68
69    /// Returns a watch receiver for connection state changes.
70    ///
71    /// Use this to observe state transitions (Connected, Reconnecting, Disconnected).
72    pub async fn state_receiver(&self) -> tokio::sync::watch::Receiver<ConnectionState> {
73        let conn = self.connection.read().await;
74        if let Some(ref connection) = *conn {
75            connection.state_receiver()
76        } else {
77            let (tx, rx) = tokio::sync::watch::channel(ConnectionState::Disconnected {
78                reason: "Not connected".to_string(),
79            });
80            drop(tx);
81            rx
82        }
83    }
84
85    /// Get the current connection state.
86    pub async fn state(&self) -> ConnectionState {
87        let conn = self.connection.read().await;
88        if let Some(ref connection) = *conn {
89            connection.state()
90        } else {
91            ConnectionState::Disconnected {
92                reason: "Not connected".to_string(),
93            }
94        }
95    }
96
97    /// Returns true if currently connected.
98    pub async fn is_connected(&self) -> bool {
99        self.state().await.is_connected()
100    }
101
102    /// Wait until the connection is established.
103    ///
104    /// Returns immediately if already connected.
105    /// Returns error if connection is permanently closed.
106    pub async fn wait_for_connected(&self) -> Result<(), TitanClientError> {
107        let mut receiver = self.state_receiver().await;
108
109        loop {
110            let state = receiver.borrow_and_update().clone();
111            match state {
112                ConnectionState::Connected => return Ok(()),
113                ConnectionState::Disconnected { reason } => {
114                    // Check if this is a permanent disconnection
115                    if reason.contains("Max reconnect attempts") {
116                        return Err(TitanClientError::ConnectionFailed {
117                            attempts: 0,
118                            reason,
119                        });
120                    }
121                }
122                ConnectionState::Reconnecting { .. } => {}
123            }
124
125            // Wait for next state change
126            if receiver.changed().await.is_err() {
127                return Err(TitanClientError::ConnectionFailed {
128                    attempts: 0,
129                    reason: "Connection closed".to_string(),
130                });
131            }
132        }
133    }
134
135    /// Get a clone of the connection Arc.
136    async fn get_connection(&self) -> Result<Arc<Connection>, TitanClientError> {
137        let conn = self.connection.read().await;
138        conn.clone()
139            .ok_or_else(|| TitanClientError::ConnectionFailed {
140                attempts: 0,
141                reason: "Not connected".to_string(),
142            })
143    }
144
145    /// Get a clone of the stream manager Arc.
146    async fn get_stream_manager(&self) -> Result<Arc<StreamManager>, TitanClientError> {
147        let manager = self.stream_manager.read().await;
148        manager
149            .clone()
150            .ok_or_else(|| TitanClientError::ConnectionFailed {
151                attempts: 0,
152                reason: "Not connected".to_string(),
153            })
154    }
155
156    /// Get the current active stream count.
157    pub async fn active_stream_count(&self) -> u32 {
158        match self.get_stream_manager().await {
159            Ok(manager) => manager.active_count(),
160            Err(_) => 0,
161        }
162    }
163
164    /// Get the current queue length.
165    pub async fn queued_stream_count(&self) -> usize {
166        match self.get_stream_manager().await {
167            Ok(manager) => manager.queue_len().await,
168            Err(_) => 0,
169        }
170    }
171
172    /// Graceful shutdown: stops all streams, then closes WebSocket.
173    ///
174    /// This method:
175    /// 1. Sends StopStream for all active streams
176    /// 2. Clears the stream manager
177    /// 3. Closes the WebSocket connection
178    ///
179    /// After calling this, the client cannot be reused.
180    #[tracing::instrument(skip_all)]
181    pub async fn close(&self) -> Result<(), TitanClientError> {
182        // First, shutdown the connection (stops all streams)
183        {
184            let conn = self.connection.read().await;
185            if let Some(ref connection) = *conn {
186                connection.shutdown().await;
187            }
188        }
189
190        // Clear stream manager
191        {
192            let mut manager = self.stream_manager.write().await;
193            *manager = None;
194        }
195
196        // Clear connection (this will cause the background loop to exit)
197        {
198            let mut conn = self.connection.write().await;
199            *conn = None;
200        }
201
202        Ok(())
203    }
204
205    /// Check if the client has been closed.
206    pub async fn is_closed(&self) -> bool {
207        let conn = self.connection.read().await;
208        conn.is_none()
209    }
210
211    // ========== One-shot API methods ==========
212
213    /// Get server info and connection limits.
214    #[tracing::instrument(skip_all)]
215    pub async fn get_info(&self) -> Result<ServerInfo, TitanClientError> {
216        let connection = self.get_connection().await?;
217        let response = connection
218            .send_request(RequestData::GetInfo(GetInfoRequest { dummy: None }))
219            .await?;
220
221        match response.data {
222            ResponseData::GetInfo(info) => Ok(info),
223            other => Err(TitanClientError::Unexpected(anyhow::anyhow!(
224                "Unexpected response type: expected GetInfo, got {:?}",
225                std::mem::discriminant(&other)
226            ))),
227        }
228    }
229
230    /// Get available venues.
231    #[tracing::instrument(skip_all)]
232    pub async fn get_venues(&self) -> Result<VenueInfo, TitanClientError> {
233        let connection = self.get_connection().await?;
234        let response = connection
235            .send_request(RequestData::GetVenues(GetVenuesRequest {
236                include_program_ids: Some(true),
237            }))
238            .await?;
239
240        match response.data {
241            ResponseData::GetVenues(venues) => Ok(venues),
242            other => Err(TitanClientError::Unexpected(anyhow::anyhow!(
243                "Unexpected response type: expected GetVenues, got {:?}",
244                std::mem::discriminant(&other)
245            ))),
246        }
247    }
248
249    /// List available providers.
250    #[tracing::instrument(skip_all)]
251    pub async fn list_providers(&self) -> Result<Vec<ProviderInfo>, TitanClientError> {
252        let connection = self.get_connection().await?;
253        let response = connection
254            .send_request(RequestData::ListProviders(ListProvidersRequest {
255                include_icons: Some(true),
256            }))
257            .await?;
258
259        match response.data {
260            ResponseData::ListProviders(providers) => Ok(providers),
261            other => Err(TitanClientError::Unexpected(anyhow::anyhow!(
262                "Unexpected response type: expected ListProviders, got {:?}",
263                std::mem::discriminant(&other)
264            ))),
265        }
266    }
267
268    /// Get a point-in-time swap price.
269    #[tracing::instrument(skip_all)]
270    pub async fn get_swap_price(
271        &self,
272        request: SwapPriceRequest,
273    ) -> Result<SwapPrice, TitanClientError> {
274        let connection = self.get_connection().await?;
275        let response = connection
276            .send_request(RequestData::GetSwapPrice(request))
277            .await?;
278
279        match response.data {
280            ResponseData::GetSwapPrice(price) => Ok(price),
281            other => Err(TitanClientError::Unexpected(anyhow::anyhow!(
282                "Unexpected response type: expected GetSwapPrice, got {:?}",
283                std::mem::discriminant(&other)
284            ))),
285        }
286    }
287
288    // ========== Streaming API methods ==========
289
290    /// Start a new swap quote stream.
291    ///
292    /// Returns a `QuoteStream` that yields `SwapQuotes` updates.
293    /// The stream automatically sends `StopStream` when dropped.
294    ///
295    /// If the server's max concurrent streams limit is reached, the request
296    /// will be queued and started automatically when a slot becomes available.
297    #[tracing::instrument(skip_all)]
298    pub async fn new_swap_quote_stream(
299        &self,
300        request: SwapQuoteRequest,
301    ) -> Result<QuoteStream, TitanClientError> {
302        let manager = self.get_stream_manager().await?;
303        manager.request_stream(request).await
304    }
305}