solana_trader_client_rust/provider/ws/mod.rs

1pub mod quote;
2pub mod stream;
3pub mod swap;
4
5use anyhow::{anyhow, Result};
6use serde_json::json;
7use solana_sdk::pubkey::Pubkey;
8use solana_sdk::signature::Keypair;
9use solana_trader_proto::api::{self, PostSubmitPaladinRequest, GetRecentBlockHashResponseV2};
10
11use crate::common::signing::{sign_transaction, SubmitParams};
12use crate::common::{get_base_url_from_env, is_submit_only_endpoint, ws_endpoint, BaseConfig};
13use crate::connections::ws::WS;
14use crate::provider::utils::timestamp_rfc3339;
15
16use super::utils::IntoTransactionMessage;
17
18pub struct WebSocketConfig {
19    pub endpoint: String,
20    pub private_key: Option<Keypair>,
21    pub auth_header: String,
22    pub use_tls: bool,
23    pub disable_auth: bool,
24}
25
26pub struct WebSocketClient {
27    conn: WS,
28    keypair: Option<Keypair>,
29    pub public_key: Option<Pubkey>,
30}
31
32impl WebSocketClient {
33    pub fn get_keypair(&self) -> Result<&Keypair> {
34        Ok(self.keypair.as_ref().unwrap())
35    }
36
37    pub async fn new(endpoint: Option<String>) -> Result<Self> {
38        let base = BaseConfig::try_from_env()?;
39        let (default_base_url, secure) = get_base_url_from_env();
40        let final_base_url = endpoint.unwrap_or(default_base_url);
41        let endpoint = ws_endpoint(&final_base_url, secure);
42
43        is_submit_only_endpoint(&final_base_url);
44
45        if base.auth_header.is_empty() {
46            return Err(anyhow::anyhow!("AUTH_HEADER is empty"));
47        }
48
49        let conn = WS::new(Some(endpoint))
50            .await
51            .map_err(|e| anyhow::anyhow!("Connection timeout: {}", e))?;
52
53        Ok(Self {
54            conn,
55            keypair: base.keypair,
56            public_key: base.public_key,
57        })
58    }
59
60    pub async fn close(self) -> Result<()> {
61        self.conn.close().await
62    }
63
64    pub async fn sign_and_submit<T: IntoTransactionMessage + Clone>(
65        &self,
66        txs: Vec<T>,
67        submit_opts: SubmitParams,
68        use_bundle: bool,
69    ) -> Result<Vec<String>> {
70        let keypair = self.get_keypair()?;
71
72        let hash_res: GetRecentBlockHashResponseV2 =
73            self.conn.request("GetRecentBlockHashV2", json!({})).await?;
74
75        if txs.len() == 1 {
76            let signed_tx = sign_transaction(&txs[0], keypair, hash_res.block_hash).await?;
77
78            let request = json!({
79                "transaction": {
80                    "content": signed_tx.content,
81                    "isCleanup": signed_tx.is_cleanup
82                },
83                "skipPreFlight": submit_opts.skip_pre_flight,
84                "frontRunningProtection": submit_opts.front_running_protection,
85                "useStakedRPCs": submit_opts.use_staked_rpcs,
86                "fastBestEffort": submit_opts.fast_best_effort
87            });
88
89            let response: serde_json::Value = self.conn.request("PostSubmitV2", request).await?;
90
91            return Ok(vec![response
92                .get("signature")
93                .and_then(|s| s.as_str())
94                .map(String::from)
95                .ok_or_else(|| anyhow!("Missing signature in response"))?]);
96        }
97
98        let mut entries = Vec::with_capacity(txs.len());
99        for tx in txs {
100            let signed_tx = sign_transaction(&tx, keypair, hash_res.block_hash.clone()).await?;
101            entries.push(json!({
102                "transaction": {
103                    "content": signed_tx.content,
104                    "isCleanup": signed_tx.is_cleanup
105                },
106                "skipPreFlight": submit_opts.skip_pre_flight,
107                "frontRunningProtection": submit_opts.front_running_protection,
108                "useStakedRPCs": submit_opts.use_staked_rpcs,
109                "fastBestEffort": submit_opts.fast_best_effort
110            }));
111        }
112
113        let request = json!({
114            "entries": entries,
115            "useBundle": use_bundle,
116            "submitStrategy": submit_opts.submit_strategy
117        });
118
119        let response: serde_json::Value = self.conn.request("PostSubmitBatchV2", request).await?;
120
121        let signatures = response["transactions"]
122            .as_array()
123            .ok_or_else(|| anyhow!("Invalid response format"))?
124            .iter()
125            .filter(|entry| entry["submitted"].as_bool().unwrap_or(false))
126            .filter_map(|entry| entry["signature"].as_str().map(String::from))
127            .collect();
128
129        Ok(signatures)
130    }
131
132    pub async fn sign_and_submit_snipe<T: IntoTransactionMessage + Clone>(
133        &self,
134        txs: Vec<T>,
135        use_staked_rpcs: bool,
136    ) -> Result<Vec<String>> {
137        let keypair = self.get_keypair()?;
138
139        let hash_res: GetRecentBlockHashResponseV2 =
140            self.conn.request("GetRecentBlockHashV2", json!({})).await?;
141
142        // Build entries for each transaction
143        let mut entries = Vec::with_capacity(txs.len());
144        for tx in txs {
145            let signed_tx = sign_transaction(&tx, keypair, hash_res.block_hash.clone()).await?;
146            entries.push(json!({
147                "transaction": {
148                    "content": signed_tx.content,
149                    "isCleanup": signed_tx.is_cleanup
150                },
151                "skipPreFlight": false
152            }));
153        }
154
155        let request = json!({
156            "entries": entries,
157            "useStakedRPCs": use_staked_rpcs,
158            "timestamp": timestamp_rfc3339()
159        });
160
161        let response: serde_json::Value = self.conn.request("PostSubmitSnipeV2", request).await?;
162
163        let signatures = response["transactions"]
164            .as_array()
165            .ok_or_else(|| anyhow!("Invalid response format"))?
166            .iter()
167            .filter(|entry| entry["submitted"].as_bool().unwrap_or(false))
168            .filter_map(|entry| entry["signature"].as_str().map(String::from))
169            .collect();
170
171        Ok(signatures)
172    }
173
174    pub async fn sign_and_submit_paladin<T: IntoTransactionMessage + Clone>(
175        &self,
176        tx: T,
177        revert_protection: bool,
178    ) -> Result<String> {
179        let hash_res: GetRecentBlockHashResponseV2 =
180            self.conn.request("GetRecentBlockHashV2", json!({})).await?;
181
182        let keypair = self.get_keypair()?;
183        let signed_tx = sign_transaction(&tx, keypair, hash_res.block_hash).await?;
184
185        let request = json!({
186            "transaction": {
187                "content": signed_tx.content,
188            },
189            "revertProtection": revert_protection,
190            "timestamp": timestamp_rfc3339()
191        });
192
193        let response: serde_json::Value = self.conn.request("PostSubmitPaladinV2", request).await?;
194
195        let signature = response
196            .get("signature")
197            .and_then(|s| s.as_str())
198            .map(String::from)
199            .ok_or_else(|| anyhow!("Missing signature in response"))?;
200
201        Ok(signature)
202    }
203
204    pub async fn get_transaction(
205        &self,
206        request: api::GetTransactionRequest,
207    ) -> anyhow::Result<api::GetTransactionResponse> {
208        let params = serde_json::to_value(request)
209            .map_err(|e| anyhow::anyhow!("Failed to serialize request: {}", e))?;
210
211        self.conn.request("GetTransaction", params).await
212    }
213
214    pub async fn get_recent_block_hash(
215        &self,
216        request: api::GetRecentBlockHashRequest,
217    ) -> anyhow::Result<api::GetRecentBlockHashResponse> {
218        let params = serde_json::to_value(request)
219            .map_err(|e| anyhow::anyhow!("Failed to serialize request: {}", e))?;
220
221        self.conn.request("GetRecentBlockHash", params).await
222    }
223
224    pub async fn get_recent_block_hash_v2(
225        &self,
226        request: &api::GetRecentBlockHashRequestV2,
227    ) -> anyhow::Result<api::GetRecentBlockHashResponseV2> {
228        let params = serde_json::to_value(request)
229            .map_err(|e| anyhow::anyhow!("Failed to serialize request: {}", e))?;
230
231        self.conn.request("GetRecentBlockHashV2", params).await
232    }
233    pub async fn get_rate_limit(
234        &self,
235        request: api::GetRateLimitRequest,
236    ) -> anyhow::Result<api::GetRateLimitResponse> {
237        let params = serde_json::to_value(request)
238            .map_err(|e| anyhow::anyhow!("Failed to serialize request: {}", e))?;
239
240        self.conn.request("GetRateLimit", params).await
241    }
242
243    pub async fn get_account_balance_v2(
244        &self,
245        request: api::GetAccountBalanceRequest,
246    ) -> anyhow::Result<api::GetAccountBalanceResponse> {
247        let params = serde_json::to_value(request)
248            .map_err(|e| anyhow::anyhow!("Failed to serialize request: {}", e))?;
249
250        self.conn.request("GetAccountBalanceV2", params).await
251    }
252
253    pub async fn get_priority_fee(
254        &self,
255        project: api::Project,
256        percentile: Option<f64>,
257    ) -> Result<api::GetPriorityFeeResponse> {
258        let request = api::GetPriorityFeeRequest {
259            project: project as i32,
260            percentile,
261        };
262
263        let params = serde_json::to_value(request)
264            .map_err(|e| anyhow::anyhow!("Failed to serialize request: {}", e))?;
265
266        self.conn.request("GetPriorityFee", params).await
267    }
268
269    pub async fn get_priority_fee_by_program(
270        &self,
271        programs: Vec<String>,
272    ) -> Result<api::GetPriorityFeeByProgramResponse> {
273        let request = api::GetPriorityFeeByProgramRequest { programs };
274
275        let params = serde_json::to_value(request)
276            .map_err(|e| anyhow::anyhow!("Failed to serialize request: {}", e))?;
277
278        self.conn.request("GetPriorityFeeByProgram", params).await
279    }
280
281    pub async fn get_token_accounts(
282        &self,
283        owner_address: String,
284    ) -> Result<api::GetTokenAccountsResponse> {
285        let request = api::GetTokenAccountsRequest { owner_address };
286
287        let params = serde_json::to_value(request)
288            .map_err(|e| anyhow::anyhow!("Failed to serialize request: {}", e))?;
289
290        self.conn.request("GetTokenAccounts", params).await
291    }
292
293    pub async fn get_account_balance(
294        &self,
295        owner_address: String,
296    ) -> Result<api::GetAccountBalanceResponse> {
297        let request = api::GetAccountBalanceRequest { owner_address };
298
299        let params = serde_json::to_value(request)
300            .map_err(|e| anyhow::anyhow!("Failed to serialize request: {}", e))?;
301
302        self.conn.request("GetAccountBalance", params).await
303    }
304
305    pub async fn get_server_time(
306        &self
307    ) -> Result<api::GetServerTimeResponse> {
308        self.conn.request("GetServerTime", json!({})).await
309    }
310
311    pub async fn post_submit(
312        &self,
313        request: &api::PostSubmitRequest
314    ) -> Result<api::PostSubmitResponse> {
315        let params = json!({
316            "transaction": request.transaction,
317            "skipPreFlight": request.skip_pre_flight,
318            "frontRunningProtection": request.front_running_protection,
319            "tip": request.tip,
320            "useStakedRPCs": request.use_staked_rp_cs,
321            "fastBestEffort": request.fast_best_effort,
322            "allowBackRun": request.allow_back_run,
323            "revenueAddress": request.revenue_address,
324            "sniping": request.sniping,
325            "timestamp": timestamp_rfc3339()
326        });
327        self.conn.request("PostSubmit", params).await
328    }
329
330    pub async fn post_submit_paladin_v2(
331        &mut self,
332        request: &PostSubmitPaladinRequest,
333    ) -> Result<api::PostSubmitResponse> {
334
335        let params = json!({
336            "transaction": request.transaction,
337            "revertProtection": request.revert_protection,
338            "timestamp": timestamp_rfc3339()
339        });
340        self.conn.request("PostSubmitPaladinV2", params).await
341    }
342
343    pub async fn post_submit_v2(
344        &self,
345        request: &api::PostSubmitRequest
346    ) -> Result<api::PostSubmitResponse> {
347        let params = json!({
348            "transaction": request.transaction,
349            "skipPreFlight": request.skip_pre_flight,
350            "frontRunningProtection": request.front_running_protection,
351            "tip": request.tip,
352            "useStakedRPCs": request.use_staked_rp_cs,
353            "fastBestEffort": request.fast_best_effort,
354            "allowBackRun": request.allow_back_run,
355            "revenueAddress": request.revenue_address,
356            "sniping": request.sniping,
357            "timestamp": timestamp_rfc3339()
358        });
359        self.conn.request("PostSubmitV2", params).await
360    }
361
362    pub async fn post_submit_batch(
363        &self,
364        request: &api::PostSubmitBatchRequest
365    ) -> Result<api::PostSubmitBatchResponse> {
366        let params = json!({
367            "entries": request.entries,
368            "submitStrategy": request.submit_strategy,
369            "useBundle": request.use_bundle,
370            "frontRunningProtection": request.front_running_protection,
371            "timestamp": timestamp_rfc3339()
372        });
373        self.conn.request("PostSubmitBatch", params).await
374    }
375}