solana_trader_client_rust/connections/
ws.rs1use anyhow::Result;
2use futures_util::{SinkExt, StreamExt};
3use rustls::crypto::ring::default_provider;
4use rustls::crypto::CryptoProvider;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use serde_json::{json, Value};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::net::TcpStream;
13use tokio::sync::mpsc::Sender;
14use tokio::sync::{broadcast, mpsc, Mutex};
15use tokio::time::timeout;
16use tokio_rustls::rustls::{ClientConfig, RootCertStore};
17use tokio_stream::wrappers::ReceiverStream;
18use tokio_stream::Stream;
19use tokio_tungstenite::tungstenite::client::IntoClientRequest;
20use tokio_tungstenite::tungstenite::handshake::client::Request;
21use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
22use tokio_tungstenite::{connect_async_tls_with_config, Connector};
23use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream};
24use url::Url;
25
26use crate::common::{get_base_url_from_env, ws_endpoint, BaseConfig};
27use crate::provider::utils::convert_string_enums;
28
/// Overall connection-retry budget; divided by [`CONNECTION_RETRY_INTERVAL`] to
/// derive the maximum number of retries in `WS::connect`.
const CONNECTION_RETRY_TIMEOUT: Duration = Duration::from_millis(15_000);
/// Pause between consecutive connection attempts.
const CONNECTION_RETRY_INTERVAL: Duration = Duration::from_millis(100);
/// Capacity of the per-subscription notification channel created by `stream_proto`.
const SUBSCRIPTION_BUFFER: usize = 1_000;
/// Period of the keep-alive ping ticker in `ping_loop`.
const PING_INTERVAL: Duration = Duration::from_secs(30);
33
34#[derive(Debug)]
35pub struct Subscription {
36 sender: mpsc::Sender<Value>,
37}
38
39#[derive(Clone)]
40struct ResponseUpdate {
41 response: String,
42}
43
44struct RequestTracker {
45 ch: Sender<ResponseUpdate>,
46}
47
48pub struct WS {
49 stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
50 write_tx: Sender<Message>,
51 shutdown_tx: broadcast::Sender<()>,
52 request_id: AtomicU64,
53 request_map: Arc<Mutex<HashMap<u64, RequestTracker>>>,
54 subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
55}
56
57impl WS {
58 pub async fn new(endpoint: Option<String>) -> Result<Self> {
59 let base = BaseConfig::try_from_env()?;
60 let (base_url, secure) = get_base_url_from_env();
61 let endpoint = endpoint.unwrap_or_else(|| ws_endpoint(&base_url, secure));
62
63 if base.auth_header.is_empty() {
64 return Err(anyhow::anyhow!("AUTH_HEADER is empty"));
65 }
66
67 let url =
68 Url::parse(&endpoint).map_err(|e| anyhow::anyhow!("Invalid WebSocket URL: {}", e))?;
69
70 let stream = Self::connect(&url, &base.auth_header).await?;
71 let stream = Arc::new(Mutex::new(stream));
72
73 let (write_tx, write_rx) = mpsc::channel(100);
74 let (shutdown_tx, _) = broadcast::channel(1);
75
76 let ws = Self {
77 stream: stream.clone(),
78 write_tx,
79 shutdown_tx,
80 request_id: AtomicU64::new(0),
81 request_map: Arc::new(Mutex::new(HashMap::new())),
82 subscriptions: Arc::new(Mutex::new(HashMap::new())),
83 };
84
85 ws.start_loops(stream, write_rx);
86 Ok(ws)
87 }
88
89 async fn connect(
90 url: &Url,
91 auth_header: &str,
92 ) -> Result<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>> {
93 let request = Self::build_request(url, auth_header)?;
94
95 let mut retry_count = 0;
96 let max_retries =
97 (CONNECTION_RETRY_TIMEOUT.as_millis() / CONNECTION_RETRY_INTERVAL.as_millis()) as u32;
98
99 loop {
100 match connect_async_tls_with_config(
101 request.clone(),
102 Some(WebSocketConfig::default()),
103 true,
104 Some(Connector::Rustls(Self::setup_tls()?)),
105 )
106 .await
107 {
108 Ok((stream, _)) => {
109 println!("Connected to: {}", url);
110 return Ok(stream);
111 }
112 Err(e) => {
113 if retry_count >= max_retries {
114 return Err(anyhow::anyhow!(
115 "WebSocket connection failed after {} retries: {}",
116 max_retries,
117 e
118 ));
119 }
120 retry_count += 1;
121 tokio::time::sleep(CONNECTION_RETRY_INTERVAL).await;
122 }
123 }
124 }
125 }
126
127 fn build_request(url: &Url, auth_header: &str) -> Result<Request> {
128 let mut request = url
129 .as_str()
130 .into_client_request()
131 .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
132
133 let headers = request.headers_mut();
134 headers.insert("Authorization", auth_header.parse()?);
135 headers.insert("x-sdk", "rust-client".parse()?);
136 headers.insert("x-sdk-version", env!("CARGO_PKG_VERSION").parse()?);
137 headers.insert("Connection", "Upgrade".parse()?);
138 headers.insert("Upgrade", "websocket".parse()?);
139 headers.insert("Sec-WebSocket-Version", "13".parse()?);
140
141 Ok(request)
142 }
143
144 fn setup_tls() -> Result<Arc<ClientConfig>> {
145 if CryptoProvider::get_default().is_none() {
146 default_provider()
147 .install_default()
148 .map_err(|e| anyhow::anyhow!("Failed to install crypto provider: {:?}", e))?;
149 }
150
151 let root_store = RootCertStore {
152 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
153 };
154
155 let tls_config = ClientConfig::builder()
156 .with_root_certificates(root_store)
157 .with_no_client_auth();
158
159 Ok(Arc::new(tls_config))
160 }
161
162 fn start_loops(
163 &self,
164 stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
165 write_rx: mpsc::Receiver<Message>,
166 ) {
167 let write_stream = stream.clone();
168 tokio::spawn(write_loop(write_stream, write_rx));
169
170 let read_stream = stream.clone();
171 let request_map = self.request_map.clone();
172 let subscriptions = self.subscriptions.clone();
173 tokio::spawn(read_loop(read_stream, request_map, subscriptions));
174
175 let ping_stream = stream;
176 let shutdown_rx = self.shutdown_tx.subscribe();
177 tokio::spawn(ping_loop(ping_stream, shutdown_rx));
178 }
179
180 pub async fn request<T>(&self, method: &str, params: Value) -> Result<T>
181 where
182 T: DeserializeOwned,
183 {
184 let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
185 let request_json = json!({
186 "jsonrpc": "2.0",
187 "id": request_id,
188 "method": method,
189 "params": params
190 });
191
192 let (tx, mut rx) = mpsc::channel(1);
193 {
194 let mut request_map = self.request_map.lock().await;
195 request_map.insert(request_id, RequestTracker { ch: tx });
196 }
197
198 let msg = Message::Text(request_json.to_string());
199 timeout(Duration::from_secs(5), self.write_tx.send(msg))
200 .await
201 .map_err(|_| anyhow::anyhow!("Request send timeout"))?
202 .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
203
204 let response = timeout(Duration::from_secs(10), rx.recv())
205 .await
206 .map_err(|_| anyhow::anyhow!("Response timeout"))?
207 .ok_or_else(|| anyhow::anyhow!("Channel closed unexpectedly"))?;
208
209 let json_response: Value = serde_json::from_str(&response.response)
210 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
211
212 if let Some(error) = json_response.get("error") {
213 return Err(anyhow::anyhow!("RPC error: {}", error));
214 }
215
216 let result = json_response
217 .get("result")
218 .ok_or_else(|| anyhow::anyhow!("Missing result field in response"))?;
219
220 let mut res = result.clone();
221 convert_string_enums(&mut res);
222
223 serde_json::from_value(res).map_err(|e| anyhow::anyhow!("Failed to parse result: {}", e))
224 }
225
226 pub async fn stream_proto<Req, Resp>(
227 &self,
228 method: &str,
229 request: &Req,
230 ) -> Result<impl Stream<Item = Result<Resp>> + Unpin>
231 where
232 Req: prost::Message + Serialize,
233 Resp: prost::Message + Default + DeserializeOwned + Send + Clone + 'static,
234 {
235 let (tx, rx) = mpsc::channel(SUBSCRIPTION_BUFFER);
236
237 let params = serde_json::to_value(request)?;
238 let params_array = json!([method, params]);
239 let subscription_id: String = self.request("subscribe", params_array).await?;
240
241 {
242 let mut subs = self.subscriptions.lock().await;
243 subs.insert(subscription_id, Subscription { sender: tx });
244 }
245
246 Ok(ReceiverStream::new(rx).map(|value: Value| {
247 let mut modified_value = value;
248 convert_string_enums(&mut modified_value);
249
250 serde_json::from_value(modified_value)
251 .map_err(|e| anyhow::anyhow!("Failed to parse stream value: {}", e))
252 }))
253 }
254
255 pub async fn close(self) -> Result<()> {
256 let _ = self.shutdown_tx.send(());
257
258 {
259 let mut request_map = self.request_map.lock().await;
260 for (_, tracker) in request_map.drain() {
261 let _ = tracker
262 .ch
263 .send(ResponseUpdate {
264 response: String::from("{\"error\":\"connection closed\"}"),
265 })
266 .await;
267 }
268 }
269
270 let mut stream = self.stream.lock().await;
271 if let Err(e) = stream.close(None).await {
272 eprintln!("Error during WebSocket close: {}", e);
273 }
274 drop(stream);
275
276 tokio::time::sleep(Duration::from_millis(100)).await;
277 println!("WebSocket shutdown complete");
278 Ok(())
279 }
280}
281
282async fn write_loop(
283 stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
284 mut write_rx: mpsc::Receiver<Message>,
285) {
286 while let Some(msg) = write_rx.recv().await {
287 let mut stream = stream.lock().await;
288 if let Err(e) = stream.send(msg).await {
289 eprintln!("Write error: {}", e);
290 break;
291 }
292 }
293}
294
295async fn read_loop(
296 stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
297 request_map: Arc<Mutex<HashMap<u64, RequestTracker>>>,
298 subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
299) {
300 loop {
301 let mut stream = stream.lock().await;
302 let Ok(Some(Ok(msg))) = timeout(Duration::from_millis(100), stream.next()).await else {
303 continue;
304 };
305
306 match msg {
307 Message::Text(text) => {
308 if let Ok(value) = serde_json::from_str(&text) {
309 handle_message(&value, &request_map, &subscriptions, &text).await;
310 }
311 }
312 Message::Close(_) => break,
313 _ => (),
314 }
315 }
316}
317
318async fn ping_loop(
319 stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
320 mut shutdown_rx: broadcast::Receiver<()>,
321) {
322 let mut interval = tokio::time::interval(PING_INTERVAL);
323 loop {
324 tokio::select! {
325 _ = interval.tick() => {
326 let mut stream = stream.lock().await;
327 if let Err(e) = stream.send(Message::Ping(vec![])).await {
328 eprintln!("Ping error: {}", e);
329 break;
330 }
331 }
332 Ok(_) = shutdown_rx.recv() => break,
333 }
334 }
335}
336
337async fn handle_message(
338 value: &Value,
339 request_map: &Arc<Mutex<HashMap<u64, RequestTracker>>>,
340 subscriptions: &Arc<Mutex<HashMap<String, Subscription>>>,
341 text: &str,
342) {
343 match value.get("id").and_then(|id| id.as_u64()) {
344 Some(id) => {
345 if let Some(tracker) = request_map.lock().await.get(&id) {
346 let _ = tracker
347 .ch
348 .send(ResponseUpdate {
349 response: text.to_string(),
350 })
351 .await;
352 }
353 }
354 None => handle_subscription(value, subscriptions).await,
355 }
356}
357
358async fn handle_subscription(
359 map: &Value,
360 subscriptions: &Arc<Mutex<HashMap<String, Subscription>>>,
361) {
362 let Some(id) = map
363 .get("params")
364 .and_then(|p| p.get("subscription"))
365 .and_then(|s| s.as_str())
366 else {
367 return;
368 };
369
370 let Some(result) = map.get("params").and_then(|p| p.get("result")) else {
371 return;
372 };
373
374 if let Some(sub) = subscriptions.lock().await.get(id) {
375 let _ = sub.sender.send(result.clone()).await;
376 }
377}