solana_trader_client_rust/connections/
ws.rs1use anyhow::Result;
2use futures_util::{SinkExt, StreamExt};
3use rustls::crypto::ring::default_provider;
4use rustls::crypto::CryptoProvider;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use serde_json::{json, Value};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::net::TcpStream;
13use tokio::sync::mpsc::Sender;
14use tokio::sync::{broadcast, mpsc, Mutex};
15use tokio::time::timeout;
16use tokio_rustls::rustls::{ClientConfig, RootCertStore};
17use tokio_stream::wrappers::ReceiverStream;
18use tokio_stream::Stream;
19use tokio_tungstenite::tungstenite::client::IntoClientRequest;
20use tokio_tungstenite::tungstenite::handshake::client::Request;
21use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
22use tokio_tungstenite::{connect_async_tls_with_config, Connector};
23use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream};
24use url::Url;
25
26use crate::common::constants::WARNING_TLS_SLOWDOWN;
27use crate::common::{get_base_url_from_env, ws_endpoint, BaseConfig};
28use crate::provider::utils::convert_string_enums;
29
/// Retry budget for the initial connection; divided by `CONNECTION_RETRY_INTERVAL`
/// to obtain the maximum number of retries.
const CONNECTION_RETRY_TIMEOUT: Duration = Duration::from_millis(15_000);
/// Pause between two consecutive connection attempts.
const CONNECTION_RETRY_INTERVAL: Duration = Duration::from_millis(100);
/// Capacity of the channel created for each stream subscription.
const SUBSCRIPTION_BUFFER: usize = 1_000;
/// Period of the keep-alive ping frames written by the ping task.
const PING_INTERVAL: Duration = Duration::from_millis(15_000);
34
35#[derive(Debug)]
36pub struct Subscription {
37 sender: mpsc::Sender<Value>,
38}
39
/// A raw JSON-RPC response frame routed back to the request that is waiting for it.
#[derive(Clone)]
struct ResponseUpdate {
    /// The unparsed text of the response frame.
    response: String,
}
44
45struct RequestTracker {
46 ch: Sender<ResponseUpdate>,
47}
48
49pub struct WS {
50 stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
51 write_tx: Sender<Message>,
52 shutdown_tx: broadcast::Sender<()>,
53 request_id: AtomicU64,
54 request_map: Arc<Mutex<HashMap<u64, RequestTracker>>>,
55 subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
56}
57
58impl WS {
59 pub async fn new(endpoint: Option<String>) -> Result<Self> {
60 let base = BaseConfig::try_from_env()?;
61 let (base_url, secure) = get_base_url_from_env();
62 let endpoint = endpoint.unwrap_or_else(|| ws_endpoint(&base_url, secure));
63 if endpoint.starts_with("wss://") {
64 println!("{}", WARNING_TLS_SLOWDOWN);
65 }
66
67 if base.auth_header.is_empty() {
68 return Err(anyhow::anyhow!("AUTH_HEADER is empty"));
69 }
70
71 let url =
72 Url::parse(&endpoint).map_err(|e| anyhow::anyhow!("Invalid WebSocket URL: {}", e))?;
73
74 let stream = Self::connect(&url, &base.auth_header).await?;
75 let stream = Arc::new(Mutex::new(stream));
76
77 let (write_tx, write_rx) = mpsc::channel(100);
78 let (shutdown_tx, _) = broadcast::channel(1);
79
80 let ws = Self {
81 stream: stream.clone(),
82 write_tx,
83 shutdown_tx,
84 request_id: AtomicU64::new(0),
85 request_map: Arc::new(Mutex::new(HashMap::new())),
86 subscriptions: Arc::new(Mutex::new(HashMap::new())),
87 };
88
89 ws.start_loops(stream, write_rx);
90 Ok(ws)
91 }
92
93 async fn connect(
94 url: &Url,
95 auth_header: &str,
96 ) -> Result<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>> {
97 let request = Self::build_request(url, auth_header)?;
98
99 let mut retry_count = 0;
100 let max_retries =
101 (CONNECTION_RETRY_TIMEOUT.as_millis() / CONNECTION_RETRY_INTERVAL.as_millis()) as u32;
102
103 loop {
104 match connect_async_tls_with_config(
105 request.clone(),
106 Some(WebSocketConfig::default()),
107 true,
108 Some(Connector::Rustls(Self::setup_tls()?)),
109 )
110 .await
111 {
112 Ok((stream, _)) => {
113 println!("Connected to: {}", url);
114 return Ok(stream);
115 }
116 Err(e) => {
117 if retry_count >= max_retries {
118 return Err(anyhow::anyhow!(
119 "WebSocket connection failed after {} retries: {}",
120 max_retries,
121 e
122 ));
123 }
124 retry_count += 1;
125 tokio::time::sleep(CONNECTION_RETRY_INTERVAL).await;
126 }
127 }
128 }
129 }
130
131 fn build_request(url: &Url, auth_header: &str) -> Result<Request> {
132 let mut request = url
133 .as_str()
134 .into_client_request()
135 .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
136
137 let headers = request.headers_mut();
138 headers.insert("Authorization", auth_header.parse()?);
139 headers.insert("x-sdk", "rust-client".parse()?);
140 headers.insert("x-sdk-version", env!("CARGO_PKG_VERSION").parse()?);
141 headers.insert("Connection", "Upgrade".parse()?);
142 headers.insert("Upgrade", "websocket".parse()?);
143 headers.insert("Sec-WebSocket-Version", "13".parse()?);
144
145 Ok(request)
146 }
147
148 fn setup_tls() -> Result<Arc<ClientConfig>> {
149 if CryptoProvider::get_default().is_none() {
150 default_provider()
151 .install_default()
152 .map_err(|e| anyhow::anyhow!("Failed to install crypto provider: {:?}", e))?;
153 }
154
155 let root_store = RootCertStore {
156 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
157 };
158
159 let tls_config = ClientConfig::builder()
160 .with_root_certificates(root_store)
161 .with_no_client_auth();
162
163 Ok(Arc::new(tls_config))
164 }
165
166 fn start_loops(
167 &self,
168 stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
169 write_rx: mpsc::Receiver<Message>,
170 ) {
171 let write_stream = stream.clone();
172 tokio::spawn(write_loop(write_stream, write_rx));
173
174 let read_stream = stream.clone();
175 let request_map = self.request_map.clone();
176 let subscriptions = self.subscriptions.clone();
177 tokio::spawn(read_loop(read_stream, request_map, subscriptions));
178
179 let ping_stream = stream;
180 let shutdown_rx = self.shutdown_tx.subscribe();
181 tokio::spawn(ping_loop(ping_stream, shutdown_rx));
182 }
183
184 pub async fn request<T>(&self, method: &str, params: Value) -> Result<T>
185 where
186 T: DeserializeOwned,
187 {
188 let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
189 let request_json = json!({
190 "jsonrpc": "2.0",
191 "id": request_id,
192 "method": method,
193 "params": params
194 });
195
196 let (tx, mut rx) = mpsc::channel(1);
197 {
198 let mut request_map = self.request_map.lock().await;
199 request_map.insert(request_id, RequestTracker { ch: tx });
200 }
201
202 let msg = Message::Text(request_json.to_string());
203 timeout(Duration::from_secs(5), self.write_tx.send(msg))
204 .await
205 .map_err(|_| anyhow::anyhow!("Request send timeout"))?
206 .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
207
208 let response = timeout(Duration::from_secs(10), rx.recv())
209 .await
210 .map_err(|_| anyhow::anyhow!("Response timeout"))?
211 .ok_or_else(|| anyhow::anyhow!("Channel closed unexpectedly"))?;
212
213 let json_response: Value = serde_json::from_str(&response.response)
214 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
215
216 if let Some(error) = json_response.get("error") {
217 return Err(anyhow::anyhow!("RPC error: {}", error));
218 }
219
220 let result = json_response
221 .get("result")
222 .ok_or_else(|| anyhow::anyhow!("Missing result field in response"))?;
223
224 let mut res = result.clone();
225 convert_string_enums(&mut res);
226
227 serde_json::from_value(res).map_err(|e| anyhow::anyhow!("Failed to parse result: {}", e))
228 }
229
230 pub async fn stream_proto<Req, Resp>(
231 &self,
232 method: &str,
233 request: &Req,
234 ) -> Result<impl Stream<Item = Result<Resp>> + Unpin>
235 where
236 Req: prost::Message + Serialize,
237 Resp: prost::Message + Default + DeserializeOwned + Send + Clone + 'static,
238 {
239 let (tx, rx) = mpsc::channel(SUBSCRIPTION_BUFFER);
240
241 let params = serde_json::to_value(request)?;
242 let params_array = json!([method, params]);
243 let subscription_id: String = self.request("subscribe", params_array).await?;
244
245 {
246 let mut subs = self.subscriptions.lock().await;
247 subs.insert(subscription_id, Subscription { sender: tx });
248 }
249
250 Ok(ReceiverStream::new(rx).map(|value: Value| {
251 let mut modified_value = value;
252 convert_string_enums(&mut modified_value);
253
254 serde_json::from_value(modified_value)
255 .map_err(|e| anyhow::anyhow!("Failed to parse stream value: {}", e))
256 }))
257 }
258
259 pub async fn close(self) -> Result<()> {
260 let _ = self.shutdown_tx.send(());
261
262 {
263 let mut request_map = self.request_map.lock().await;
264 for (_, tracker) in request_map.drain() {
265 let _ = tracker
266 .ch
267 .send(ResponseUpdate {
268 response: String::from("{\"error\":\"connection closed\"}"),
269 })
270 .await;
271 }
272 }
273
274 let mut stream = self.stream.lock().await;
275 if let Err(e) = stream.close(None).await {
276 eprintln!("Error during WebSocket close: {}", e);
277 }
278 drop(stream);
279
280 tokio::time::sleep(Duration::from_millis(100)).await;
281 println!("WebSocket shutdown complete");
282 Ok(())
283 }
284}
285
286async fn write_loop(
287 stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
288 mut write_rx: mpsc::Receiver<Message>,
289) {
290 while let Some(msg) = write_rx.recv().await {
291 let mut stream = stream.lock().await;
292 if let Err(e) = stream.send(msg).await {
293 eprintln!("Write error: {}", e);
294 break;
295 }
296 }
297}
298
299async fn read_loop(
300 stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
301 request_map: Arc<Mutex<HashMap<u64, RequestTracker>>>,
302 subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
303) {
304 loop {
305 let mut stream = stream.lock().await;
306 let Ok(Some(Ok(msg))) = timeout(Duration::from_millis(100), stream.next()).await else {
307 continue;
308 };
309
310 match msg {
311 Message::Text(text) => {
312 if let Ok(value) = serde_json::from_str(&text) {
313 handle_message(&value, &request_map, &subscriptions, &text).await;
314 }
315 }
316 Message::Close(_) => break,
317 _ => (),
318 }
319 }
320}
321
322async fn ping_loop(
323 stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
324 mut shutdown_rx: broadcast::Receiver<()>,
325) {
326 let mut interval = tokio::time::interval(PING_INTERVAL);
327 loop {
328 tokio::select! {
329 _ = interval.tick() => {
330 let mut stream = stream.lock().await;
331 if let Err(e) = stream.send(Message::Ping(vec![])).await {
332 eprintln!("Ping error: {}", e);
333 break;
334 }
335 }
336 Ok(_) = shutdown_rx.recv() => break,
337 }
338 }
339}
340
341async fn handle_message(
342 value: &Value,
343 request_map: &Arc<Mutex<HashMap<u64, RequestTracker>>>,
344 subscriptions: &Arc<Mutex<HashMap<String, Subscription>>>,
345 text: &str,
346) {
347 match value.get("id").and_then(|id| id.as_u64()) {
348 Some(id) => {
349 if let Some(tracker) = request_map.lock().await.get(&id) {
350 let _ = tracker
351 .ch
352 .send(ResponseUpdate {
353 response: text.to_string(),
354 })
355 .await;
356 }
357 }
358 None => handle_subscription(value, subscriptions).await,
359 }
360}
361
362async fn handle_subscription(
363 map: &Value,
364 subscriptions: &Arc<Mutex<HashMap<String, Subscription>>>,
365) {
366 let Some(id) = map
367 .get("params")
368 .and_then(|p| p.get("subscription"))
369 .and_then(|s| s.as_str())
370 else {
371 return;
372 };
373
374 let Some(result) = map.get("params").and_then(|p| p.get("result")) else {
375 return;
376 };
377
378 if let Some(sub) = subscriptions.lock().await.get(id) {
379 let _ = sub.sender.send(result.clone()).await;
380 }
381}