solana_trader_client_rust/connections/
ws.rs

1use anyhow::Result;
2use futures_util::{SinkExt, StreamExt};
3use rustls::crypto::ring::default_provider;
4use rustls::crypto::CryptoProvider;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use serde_json::{json, Value};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::net::TcpStream;
13use tokio::sync::mpsc::Sender;
14use tokio::sync::{broadcast, mpsc, Mutex};
15use tokio::time::timeout;
16use tokio_rustls::rustls::{ClientConfig, RootCertStore};
17use tokio_stream::wrappers::ReceiverStream;
18use tokio_stream::Stream;
19use tokio_tungstenite::tungstenite::client::IntoClientRequest;
20use tokio_tungstenite::tungstenite::handshake::client::Request;
21use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
22use tokio_tungstenite::{connect_async_tls_with_config, Connector};
23use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream};
24use url::Url;
25
26use crate::common::constants::WARNING_TLS_SLOWDOWN;
27use crate::common::{get_base_url_from_env, ws_endpoint, BaseConfig};
28use crate::provider::utils::convert_string_enums;
29
/// Total retry budget for opening the connection; divided by
/// `CONNECTION_RETRY_INTERVAL` to obtain the maximum number of retries.
const CONNECTION_RETRY_TIMEOUT: Duration = Duration::from_millis(15_000);
/// Pause between two consecutive connection attempts.
const CONNECTION_RETRY_INTERVAL: Duration = Duration::from_millis(100);
/// Capacity of the channel that buffers notifications for one subscription.
const SUBSCRIPTION_BUFFER: usize = 1_000;
/// Period between WebSocket ping frames sent by the ping loop.
const PING_INTERVAL: Duration = Duration::from_secs(15);
34
35#[derive(Debug)]
36pub struct Subscription {
37    sender: mpsc::Sender<Value>,
38}
39
/// Raw response frame handed from the read loop to the waiting `request` call.
#[derive(Clone)]
struct ResponseUpdate {
    /// The response exactly as received, still unparsed JSON text.
    response: String,
}
44
45struct RequestTracker {
46    ch: Sender<ResponseUpdate>,
47}
48
49pub struct WS {
50    stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
51    write_tx: Sender<Message>,
52    shutdown_tx: broadcast::Sender<()>,
53    request_id: AtomicU64,
54    request_map: Arc<Mutex<HashMap<u64, RequestTracker>>>,
55    subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
56}
57
58impl WS {
59    pub async fn new(endpoint: Option<String>) -> Result<Self> {
60        let base = BaseConfig::try_from_env()?;
61        let (base_url, secure) = get_base_url_from_env();
62        let endpoint = endpoint.unwrap_or_else(|| ws_endpoint(&base_url, secure));
63        if endpoint.starts_with("wss://") {
64            println!("{}", WARNING_TLS_SLOWDOWN);
65        }
66
67        if base.auth_header.is_empty() {
68            return Err(anyhow::anyhow!("AUTH_HEADER is empty"));
69        }
70
71        let url =
72            Url::parse(&endpoint).map_err(|e| anyhow::anyhow!("Invalid WebSocket URL: {}", e))?;
73
74        let stream = Self::connect(&url, &base.auth_header).await?;
75        let stream = Arc::new(Mutex::new(stream));
76
77        let (write_tx, write_rx) = mpsc::channel(100);
78        let (shutdown_tx, _) = broadcast::channel(1);
79
80        let ws = Self {
81            stream: stream.clone(),
82            write_tx,
83            shutdown_tx,
84            request_id: AtomicU64::new(0),
85            request_map: Arc::new(Mutex::new(HashMap::new())),
86            subscriptions: Arc::new(Mutex::new(HashMap::new())),
87        };
88
89        ws.start_loops(stream, write_rx);
90        Ok(ws)
91    }
92
93    async fn connect(
94        url: &Url,
95        auth_header: &str,
96    ) -> Result<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>> {
97        let request = Self::build_request(url, auth_header)?;
98
99        let mut retry_count = 0;
100        let max_retries =
101            (CONNECTION_RETRY_TIMEOUT.as_millis() / CONNECTION_RETRY_INTERVAL.as_millis()) as u32;
102
103        loop {
104            match connect_async_tls_with_config(
105                request.clone(),
106                Some(WebSocketConfig::default()),
107                true,
108                Some(Connector::Rustls(Self::setup_tls()?)),
109            )
110            .await
111            {
112                Ok((stream, _)) => {
113                    println!("Connected to: {}", url);
114                    return Ok(stream);
115                }
116                Err(e) => {
117                    if retry_count >= max_retries {
118                        return Err(anyhow::anyhow!(
119                            "WebSocket connection failed after {} retries: {}",
120                            max_retries,
121                            e
122                        ));
123                    }
124                    retry_count += 1;
125                    tokio::time::sleep(CONNECTION_RETRY_INTERVAL).await;
126                }
127            }
128        }
129    }
130
131    fn build_request(url: &Url, auth_header: &str) -> Result<Request> {
132        let mut request = url
133            .as_str()
134            .into_client_request()
135            .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
136
137        let headers = request.headers_mut();
138        headers.insert("Authorization", auth_header.parse()?);
139        headers.insert("x-sdk", "rust-client".parse()?);
140        headers.insert("x-sdk-version", env!("CARGO_PKG_VERSION").parse()?);
141        headers.insert("Connection", "Upgrade".parse()?);
142        headers.insert("Upgrade", "websocket".parse()?);
143        headers.insert("Sec-WebSocket-Version", "13".parse()?);
144
145        Ok(request)
146    }
147
148    fn setup_tls() -> Result<Arc<ClientConfig>> {
149        if CryptoProvider::get_default().is_none() {
150            default_provider()
151                .install_default()
152                .map_err(|e| anyhow::anyhow!("Failed to install crypto provider: {:?}", e))?;
153        }
154
155        let root_store = RootCertStore {
156            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
157        };
158
159        let tls_config = ClientConfig::builder()
160            .with_root_certificates(root_store)
161            .with_no_client_auth();
162
163        Ok(Arc::new(tls_config))
164    }
165
166    fn start_loops(
167        &self,
168        stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
169        write_rx: mpsc::Receiver<Message>,
170    ) {
171        let write_stream = stream.clone();
172        tokio::spawn(write_loop(write_stream, write_rx));
173
174        let read_stream = stream.clone();
175        let request_map = self.request_map.clone();
176        let subscriptions = self.subscriptions.clone();
177        tokio::spawn(read_loop(read_stream, request_map, subscriptions));
178
179        let ping_stream = stream;
180        let shutdown_rx = self.shutdown_tx.subscribe();
181        tokio::spawn(ping_loop(ping_stream, shutdown_rx));
182    }
183
184    pub async fn request<T>(&self, method: &str, params: Value) -> Result<T>
185    where
186        T: DeserializeOwned,
187    {
188        let request_id = self.request_id.fetch_add(1, Ordering::SeqCst);
189        let request_json = json!({
190            "jsonrpc": "2.0",
191            "id": request_id,
192            "method": method,
193            "params": params
194        });
195
196        let (tx, mut rx) = mpsc::channel(1);
197        {
198            let mut request_map = self.request_map.lock().await;
199            request_map.insert(request_id, RequestTracker { ch: tx });
200        }
201
202        let msg = Message::Text(request_json.to_string());
203        timeout(Duration::from_secs(5), self.write_tx.send(msg))
204            .await
205            .map_err(|_| anyhow::anyhow!("Request send timeout"))?
206            .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
207
208        let response = timeout(Duration::from_secs(10), rx.recv())
209            .await
210            .map_err(|_| anyhow::anyhow!("Response timeout"))?
211            .ok_or_else(|| anyhow::anyhow!("Channel closed unexpectedly"))?;
212
213        let json_response: Value = serde_json::from_str(&response.response)
214            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
215
216        if let Some(error) = json_response.get("error") {
217            return Err(anyhow::anyhow!("RPC error: {}", error));
218        }
219
220        let result = json_response
221            .get("result")
222            .ok_or_else(|| anyhow::anyhow!("Missing result field in response"))?;
223
224        let mut res = result.clone();
225        convert_string_enums(&mut res);
226
227        serde_json::from_value(res).map_err(|e| anyhow::anyhow!("Failed to parse result: {}", e))
228    }
229
230    pub async fn stream_proto<Req, Resp>(
231        &self,
232        method: &str,
233        request: &Req,
234    ) -> Result<impl Stream<Item = Result<Resp>> + Unpin>
235    where
236        Req: prost::Message + Serialize,
237        Resp: prost::Message + Default + DeserializeOwned + Send + Clone + 'static,
238    {
239        let (tx, rx) = mpsc::channel(SUBSCRIPTION_BUFFER);
240
241        let params = serde_json::to_value(request)?;
242        let params_array = json!([method, params]);
243        let subscription_id: String = self.request("subscribe", params_array).await?;
244
245        {
246            let mut subs = self.subscriptions.lock().await;
247            subs.insert(subscription_id, Subscription { sender: tx });
248        }
249
250        Ok(ReceiverStream::new(rx).map(|value: Value| {
251            let mut modified_value = value;
252            convert_string_enums(&mut modified_value);
253
254            serde_json::from_value(modified_value)
255                .map_err(|e| anyhow::anyhow!("Failed to parse stream value: {}", e))
256        }))
257    }
258
259    pub async fn close(self) -> Result<()> {
260        let _ = self.shutdown_tx.send(());
261
262        {
263            let mut request_map = self.request_map.lock().await;
264            for (_, tracker) in request_map.drain() {
265                let _ = tracker
266                    .ch
267                    .send(ResponseUpdate {
268                        response: String::from("{\"error\":\"connection closed\"}"),
269                    })
270                    .await;
271            }
272        }
273
274        let mut stream = self.stream.lock().await;
275        if let Err(e) = stream.close(None).await {
276            eprintln!("Error during WebSocket close: {}", e);
277        }
278        drop(stream);
279
280        tokio::time::sleep(Duration::from_millis(100)).await;
281        println!("WebSocket shutdown complete");
282        Ok(())
283    }
284}
285
286async fn write_loop(
287    stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
288    mut write_rx: mpsc::Receiver<Message>,
289) {
290    while let Some(msg) = write_rx.recv().await {
291        let mut stream = stream.lock().await;
292        if let Err(e) = stream.send(msg).await {
293            eprintln!("Write error: {}", e);
294            break;
295        }
296    }
297}
298
299async fn read_loop(
300    stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
301    request_map: Arc<Mutex<HashMap<u64, RequestTracker>>>,
302    subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
303) {
304    loop {
305        let mut stream = stream.lock().await;
306        let Ok(Some(Ok(msg))) = timeout(Duration::from_millis(100), stream.next()).await else {
307            continue;
308        };
309
310        match msg {
311            Message::Text(text) => {
312                if let Ok(value) = serde_json::from_str(&text) {
313                    handle_message(&value, &request_map, &subscriptions, &text).await;
314                }
315            }
316            Message::Close(_) => break,
317            _ => (),
318        }
319    }
320}
321
322async fn ping_loop(
323    stream: Arc<Mutex<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>>>,
324    mut shutdown_rx: broadcast::Receiver<()>,
325) {
326    let mut interval = tokio::time::interval(PING_INTERVAL);
327    loop {
328        tokio::select! {
329            _ = interval.tick() => {
330                let mut stream = stream.lock().await;
331                if let Err(e) = stream.send(Message::Ping(vec![])).await {
332                    eprintln!("Ping error: {}", e);
333                    break;
334                }
335            }
336            Ok(_) = shutdown_rx.recv() => break,
337        }
338    }
339}
340
341async fn handle_message(
342    value: &Value,
343    request_map: &Arc<Mutex<HashMap<u64, RequestTracker>>>,
344    subscriptions: &Arc<Mutex<HashMap<String, Subscription>>>,
345    text: &str,
346) {
347    match value.get("id").and_then(|id| id.as_u64()) {
348        Some(id) => {
349            if let Some(tracker) = request_map.lock().await.get(&id) {
350                let _ = tracker
351                    .ch
352                    .send(ResponseUpdate {
353                        response: text.to_string(),
354                    })
355                    .await;
356            }
357        }
358        None => handle_subscription(value, subscriptions).await,
359    }
360}
361
362async fn handle_subscription(
363    map: &Value,
364    subscriptions: &Arc<Mutex<HashMap<String, Subscription>>>,
365) {
366    let Some(id) = map
367        .get("params")
368        .and_then(|p| p.get("subscription"))
369        .and_then(|s| s.as_str())
370    else {
371        return;
372    };
373
374    let Some(result) = map.get("params").and_then(|p| p.get("result")) else {
375        return;
376    };
377
378    if let Some(sub) = subscriptions.lock().await.get(id) {
379        let _ = sub.sender.send(result.clone()).await;
380    }
381}