solana_trader_client_rust/connections/
ws.rs

1use anyhow::Result;
2use futures_util::{SinkExt, StreamExt};
3use rustls::crypto::ring::default_provider;
4use rustls::crypto::CryptoProvider;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use serde_json::{json, Value};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::net::TcpStream;
13use tokio::sync::mpsc::Sender;
14use tokio::sync::{broadcast, mpsc, Mutex};
15use tokio::time::timeout;
16use tokio_rustls::rustls::{ClientConfig, RootCertStore};
17use tokio_stream::wrappers::ReceiverStream;
18use tokio_stream::Stream;
19use tokio_tungstenite::tungstenite::client::IntoClientRequest;
20use tokio_tungstenite::tungstenite::handshake::client::Request;
21use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
22use tokio_tungstenite::{connect_async_tls_with_config, Connector};
23use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream};
24use url::Url;
25
26use crate::common::{get_base_url_from_env, ws_endpoint, BaseConfig};
27use crate::provider::utils::convert_string_enums;
28
/// Retry budget for the initial connection; divided by
/// `CONNECTION_RETRY_INTERVAL` to obtain the maximum number of retries.
const CONNECTION_RETRY_TIMEOUT: Duration = Duration::from_secs(15);
/// Pause between two consecutive connection attempts.
const CONNECTION_RETRY_INTERVAL: Duration = Duration::from_millis(100);
/// Capacity of the channel that buffers updates for a single subscription.
const SUBSCRIPTION_BUFFER: usize = 1000;
/// Period between keep-alive ping frames.
const PING_INTERVAL: Duration = Duration::from_secs(30);
33
34#[derive(Debug)]
35pub struct Subscription {
36    sender: mpsc::Sender<Value>,
37}
38
/// A raw response handed from the read task to the request that awaits it.
#[derive(Clone)]
struct ResponseUpdate {
    /// Unparsed JSON text of the response message.
    response: String,
}
43
44struct RequestTracker {
45    ch: Sender<ResponseUpdate>,
46}
47
48pub struct WS {
49    stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
50    write_tx: Sender<Message>,
51    shutdown_tx: broadcast::Sender<()>,
52    request_id: AtomicU64,
53    request_map: Arc<Mutex<HashMap<u64, RequestTracker>>>,
54    subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
55}
56
57impl WS {
58    pub async fn new(endpoint: Option<String>) -> Result<Self> {
59        let base = BaseConfig::try_from_env()?;
60        let (base_url, secure) = get_base_url_from_env();
61        let endpoint = endpoint.unwrap_or_else(|| ws_endpoint(&base_url, secure));
62
63        if base.auth_header.is_empty() {
64            return Err(anyhow::anyhow!("AUTH_HEADER is empty"));
65        }
66
67        let url =
68            Url::parse(&endpoint).map_err(|e| anyhow::anyhow!("Invalid WebSocket URL: {}", e))?;
69
70        let stream = Self::connect(&url, &base.auth_header).await?;
71        let stream = Arc::new(Mutex::new(stream));
72
73        let (write_tx, write_rx) = mpsc::channel(100);
74        let (shutdown_tx, _) = broadcast::channel(1);
75
76        let ws = Self {
77            stream: stream.clone(),
78            write_tx,
79            shutdown_tx,
80            request_id: AtomicU64::new(0),
81            request_map: Arc::new(Mutex::new(HashMap::new())),
82            subscriptions: Arc::new(Mutex::new(HashMap::new())),
83        };
84
85        ws.start_loops(stream, write_rx);
86        Ok(ws)
87    }
88
89    async fn connect(
90        url: &Url,
91        auth_header: &str,
92    ) -> Result<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>> {
93        let request = Self::build_request(url, auth_header)?;
94
95        let mut retry_count = 0;
96        let max_retries =
97            (CONNECTION_RETRY_TIMEOUT.as_millis() / CONNECTION_RETRY_INTERVAL.as_millis()) as u32;
98
99        loop {
100            match connect_async_tls_with_config(
101                request.clone(),
102                Some(WebSocketConfig::default()),
103                true,
104                Some(Connector::Rustls(Self::setup_tls()?)),
105            )
106            .await
107            {
108                Ok((stream, _)) => {
109                    println!("Connected to: {}", url);
110                    return Ok(stream);
111                }
112                Err(e) => {
113                    if retry_count >= max_retries {
114                        return Err(anyhow::anyhow!(
115                            "WebSocket connection failed after {} retries: {}",
116                            max_retries,
117                            e
118                        ));
119                    }
120                    retry_count += 1;
121                    tokio::time::sleep(CONNECTION_RETRY_INTERVAL).await;
122                }
123            }
124        }
125    }
126
127    fn build_request(url: &Url, auth_header: &str) -> Result<Request> {
128        let mut request = url
129            .as_str()
130            .into_client_request()
131            .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
132
133        let headers = request.headers_mut();
134        headers.insert("Authorization", auth_header.parse()?);
135        headers.insert("x-sdk", "rust-client".parse()?);
136        headers.insert("x-sdk-version", env!("CARGO_PKG_VERSION").parse()?);
137        headers.insert("Connection", "Upgrade".parse()?);
138        headers.insert("Upgrade", "websocket".parse()?);
139        headers.insert("Sec-WebSocket-Version", "13".parse()?);
140
141        Ok(request)
142    }
143
144    fn setup_tls() -> Result<Arc<ClientConfig>> {
145        if CryptoProvider::get_default().is_none() {
146            default_provider()
147                .install_default()
148                .map_err(|e| anyhow::anyhow!("Failed to install crypto provider: {:?}", e))?;
149        }
150
151        let root_store = RootCertStore {
152            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
153        };
154
155        let tls_config = ClientConfig::builder()
156            .with_root_certificates(root_store)
157            .with_no_client_auth();
158
159        Ok(Arc::new(tls_config))
160    }
161
162    fn start_loops(
163        &self,
164        stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
165        write_rx: mpsc::Receiver<Message>,
166    ) {
167        let write_stream = stream.clone();
168        tokio::spawn(write_loop(write_stream, write_rx));
169
170        let read_stream = stream.clone();
171        let request_map = self.request_map.clone();
172        let subscriptions = self.subscriptions.clone();
173        tokio::spawn(read_loop(read_stream, request_map, subscriptions));
174
175        let ping_stream = stream;
176        let shutdown_rx = self.shutdown_tx.subscribe();
177        tokio::spawn(ping_loop(ping_stream, shutdown_rx));
178    }
179
180    pub async fn request<T>(&self, method: &str, params: Value) -> Result<T>
181    where
182        T: DeserializeOwned,
183    {
184        let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
185        let request_json = json!({
186            "jsonrpc": "2.0",
187            "id": request_id,
188            "method": method,
189            "params": params
190        });
191
192        let (tx, mut rx) = mpsc::channel(1);
193        {
194            let mut request_map = self.request_map.lock().await;
195            request_map.insert(request_id, RequestTracker { ch: tx });
196        }
197
198        let msg = Message::Text(request_json.to_string());
199        timeout(Duration::from_secs(5), self.write_tx.send(msg))
200            .await
201            .map_err(|_| anyhow::anyhow!("Request send timeout"))?
202            .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
203
204        let response = timeout(Duration::from_secs(10), rx.recv())
205            .await
206            .map_err(|_| anyhow::anyhow!("Response timeout"))?
207            .ok_or_else(|| anyhow::anyhow!("Channel closed unexpectedly"))?;
208
209        let json_response: Value = serde_json::from_str(&response.response)
210            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
211
212        if let Some(error) = json_response.get("error") {
213            return Err(anyhow::anyhow!("RPC error: {}", error));
214        }
215
216        let result = json_response
217            .get("result")
218            .ok_or_else(|| anyhow::anyhow!("Missing result field in response"))?;
219
220        let mut res = result.clone();
221        convert_string_enums(&mut res);
222
223        serde_json::from_value(res).map_err(|e| anyhow::anyhow!("Failed to parse result: {}", e))
224    }
225
226    pub async fn stream_proto<Req, Resp>(
227        &self,
228        method: &str,
229        request: &Req,
230    ) -> Result<impl Stream<Item = Result<Resp>> + Unpin>
231    where
232        Req: prost::Message + Serialize,
233        Resp: prost::Message + Default + DeserializeOwned + Send + Clone + 'static,
234    {
235        let (tx, rx) = mpsc::channel(SUBSCRIPTION_BUFFER);
236
237        let params = serde_json::to_value(request)?;
238        let params_array = json!([method, params]);
239        let subscription_id: String = self.request("subscribe", params_array).await?;
240
241        {
242            let mut subs = self.subscriptions.lock().await;
243            subs.insert(subscription_id, Subscription { sender: tx });
244        }
245
246        Ok(ReceiverStream::new(rx).map(|value: Value| {
247            let mut modified_value = value;
248            convert_string_enums(&mut modified_value);
249
250            serde_json::from_value(modified_value)
251                .map_err(|e| anyhow::anyhow!("Failed to parse stream value: {}", e))
252        }))
253    }
254
255    pub async fn close(self) -> Result<()> {
256        let _ = self.shutdown_tx.send(());
257
258        {
259            let mut request_map = self.request_map.lock().await;
260            for (_, tracker) in request_map.drain() {
261                let _ = tracker
262                    .ch
263                    .send(ResponseUpdate {
264                        response: String::from("{\"error\":\"connection closed\"}"),
265                    })
266                    .await;
267            }
268        }
269
270        let mut stream = self.stream.lock().await;
271        if let Err(e) = stream.close(None).await {
272            eprintln!("Error during WebSocket close: {}", e);
273        }
274        drop(stream);
275
276        tokio::time::sleep(Duration::from_millis(100)).await;
277        println!("WebSocket shutdown complete");
278        Ok(())
279    }
280}
281
282async fn write_loop(
283    stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
284    mut write_rx: mpsc::Receiver<Message>,
285) {
286    while let Some(msg) = write_rx.recv().await {
287        let mut stream = stream.lock().await;
288        if let Err(e) = stream.send(msg).await {
289            eprintln!("Write error: {}", e);
290            break;
291        }
292    }
293}
294
295async fn read_loop(
296    stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
297    request_map: Arc<Mutex<HashMap<u64, RequestTracker>>>,
298    subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
299) {
300    loop {
301        let mut stream = stream.lock().await;
302        let Ok(Some(Ok(msg))) = timeout(Duration::from_millis(100), stream.next()).await else {
303            continue;
304        };
305
306        match msg {
307            Message::Text(text) => {
308                if let Ok(value) = serde_json::from_str(&text) {
309                    handle_message(&value, &request_map, &subscriptions, &text).await;
310                }
311            }
312            Message::Close(_) => break,
313            _ => (),
314        }
315    }
316}
317
318async fn ping_loop(
319    stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
320    mut shutdown_rx: broadcast::Receiver<()>,
321) {
322    let mut interval = tokio::time::interval(PING_INTERVAL);
323    loop {
324        tokio::select! {
325            _ = interval.tick() => {
326                let mut stream = stream.lock().await;
327                if let Err(e) = stream.send(Message::Ping(vec![])).await {
328                    eprintln!("Ping error: {}", e);
329                    break;
330                }
331            }
332            Ok(_) = shutdown_rx.recv() => break,
333        }
334    }
335}
336
337async fn handle_message(
338    value: &Value,
339    request_map: &Arc<Mutex<HashMap<u64, RequestTracker>>>,
340    subscriptions: &Arc<Mutex<HashMap<String, Subscription>>>,
341    text: &str,
342) {
343    match value.get("id").and_then(|id| id.as_u64()) {
344        Some(id) => {
345            if let Some(tracker) = request_map.lock().await.get(&id) {
346                let _ = tracker
347                    .ch
348                    .send(ResponseUpdate {
349                        response: text.to_string(),
350                    })
351                    .await;
352            }
353        }
354        None => handle_subscription(value, subscriptions).await,
355    }
356}
357
358async fn handle_subscription(
359    map: &Value,
360    subscriptions: &Arc<Mutex<HashMap<String, Subscription>>>,
361) {
362    let Some(id) = map
363        .get("params")
364        .and_then(|p| p.get("subscription"))
365        .and_then(|s| s.as_str())
366    else {
367        return;
368    };
369
370    let Some(result) = map.get("params").and_then(|p| p.get("result")) else {
371        return;
372    };
373
374    if let Some(sub) = subscriptions.lock().await.get(id) {
375        let _ = sub.sender.send(result.clone()).await;
376    }
377}