shadow_rpc_auth/
http_sender.rs1use reqwest::header::HeaderMap;
5use serde::Deserialize;
6use serde_json::{json, Value};
7use solana_client::rpc_custom_error as custom_error;
8use solana_client::rpc_request::{RpcError, RpcRequest, RpcResponseErrorData};
9use solana_client::rpc_response::RpcSimulateTransactionResult;
10use solana_client::rpc_sender::{RpcSender, RpcTransportStats};
11use {
12 async_trait::async_trait,
13 log::*,
14 reqwest::{
15 self,
16 header::{self, CONTENT_TYPE, RETRY_AFTER},
17 StatusCode,
18 },
19 std::{
20 sync::{
21 atomic::{AtomicU64, Ordering},
22 Arc, RwLock,
23 },
24 time::{Duration, Instant},
25 },
26 tokio::time::sleep,
27};
28
29#[derive(Deserialize, Debug)]
30pub struct RpcErrorObject {
31 pub code: i64,
32 pub message: String,
33}
34
35pub struct HttpSenderWithHeaders {
36 client: Arc<reqwest::Client>,
37 url: String,
38 request_id: AtomicU64,
39 stats: RwLock<RpcTransportStats>,
40}
41
42impl HttpSenderWithHeaders {
44 pub fn new<U: ToString>(url: U, headers: Option<HeaderMap>) -> Self {
49 Self::new_with_timeout(url, Duration::from_secs(30), headers)
50 }
51
52 pub fn new_with_timeout<U: ToString>(
56 url: U,
57 timeout: Duration,
58 headers: Option<HeaderMap>,
59 ) -> Self {
60 let mut default_headers = header::HeaderMap::new();
61 default_headers.append(
62 header::HeaderName::from_static("solana-client"),
63 header::HeaderValue::from_str(
64 format!("rust/{}", solana_version::Version::default()).as_str(),
65 )
66 .unwrap(),
67 );
68 if let Some(headers) = headers {
69 default_headers.extend(headers);
70 }
71
72 let client = Arc::new(
73 reqwest::Client::builder()
74 .default_headers(default_headers)
75 .timeout(timeout)
76 .pool_idle_timeout(timeout)
77 .build()
78 .expect("build rpc client"),
79 );
80
81 Self {
82 client,
83 url: url.to_string(),
84 request_id: AtomicU64::new(0),
85 stats: RwLock::new(RpcTransportStats::default()),
86 }
87 }
88}
89
90struct StatsUpdater<'a> {
91 stats: &'a RwLock<RpcTransportStats>,
92 request_start_time: Instant,
93 rate_limited_time: Duration,
94}
95
96impl<'a> StatsUpdater<'a> {
97 fn new(stats: &'a RwLock<RpcTransportStats>) -> Self {
98 Self {
99 stats,
100 request_start_time: Instant::now(),
101 rate_limited_time: Duration::default(),
102 }
103 }
104
105 fn add_rate_limited_time(&mut self, duration: Duration) {
106 self.rate_limited_time += duration;
107 }
108}
109
110impl<'a> Drop for StatsUpdater<'a> {
111 fn drop(&mut self) {
112 let mut stats = self.stats.write().unwrap();
113 stats.request_count += 1;
114 stats.elapsed_time += Instant::now().duration_since(self.request_start_time);
115 stats.rate_limited_time += self.rate_limited_time;
116 }
117}
118
119pub fn build_request_json(req: &RpcRequest, id: u64, params: Value) -> Value {
120 let jsonrpc = "2.0";
121 json!({
122 "jsonrpc": jsonrpc,
123 "id": id,
124 "method": format!("{}", req),
125 "params": params,
126 })
127}
128
129#[async_trait]
130impl RpcSender for HttpSenderWithHeaders {
131 fn get_transport_stats(&self) -> RpcTransportStats {
132 self.stats.read().unwrap().clone()
133 }
134
135 async fn send(
136 &self,
137 request: RpcRequest,
138 params: serde_json::Value,
139 ) -> solana_client::client_error::Result<serde_json::Value> {
140 let mut stats_updater = StatsUpdater::new(&self.stats);
141
142 let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
143 let request_json = build_request_json(&request, request_id, params).to_string();
144
145 let mut too_many_requests_retries = 5;
146 loop {
147 let response = {
148 let client = self.client.clone();
149 let request_json = request_json.clone();
150 client
151 .post(&self.url)
152 .header(CONTENT_TYPE, "application/json")
153 .body(request_json)
154 .send()
155 .await
156 }?;
157
158 if !response.status().is_success() {
159 if response.status() == StatusCode::TOO_MANY_REQUESTS
160 && too_many_requests_retries > 0
161 {
162 let mut duration = Duration::from_millis(500);
163 if let Some(retry_after) = response.headers().get(RETRY_AFTER) {
164 if let Ok(retry_after) = retry_after.to_str() {
165 if let Ok(retry_after) = retry_after.parse::<u64>() {
166 if retry_after < 120 {
167 duration = Duration::from_secs(retry_after);
168 }
169 }
170 }
171 }
172
173 too_many_requests_retries -= 1;
174 debug!(
175 "Too many requests: server responded with {:?}, {} retries left, pausing for {:?}",
176 response, too_many_requests_retries, duration
177 );
178
179 sleep(duration).await;
180 stats_updater.add_rate_limited_time(duration);
181 continue;
182 }
183 return Err(response.error_for_status().unwrap_err().into());
184 }
185
186 let mut json = response.json::<serde_json::Value>().await?;
187 if json["error"].is_object() {
188 return match serde_json::from_value::<RpcErrorObject>(json["error"].clone()) {
189 Ok(rpc_error_object) => {
190 let data = match rpc_error_object.code {
191 custom_error::JSON_RPC_SERVER_ERROR_SEND_TRANSACTION_PREFLIGHT_FAILURE => {
192 match serde_json::from_value::<RpcSimulateTransactionResult>(json["error"]["data"].clone()) {
193 Ok(data) => RpcResponseErrorData::SendTransactionPreflightFailure(data),
194 Err(err) => {
195 debug!("Failed to deserialize RpcSimulateTransactionResult: {:?}", err);
196 RpcResponseErrorData::Empty
197 }
198 }
199 },
200 custom_error::JSON_RPC_SERVER_ERROR_NODE_UNHEALTHY => {
201 match serde_json::from_value::<custom_error::NodeUnhealthyErrorData>(json["error"]["data"].clone()) {
202 Ok(custom_error::NodeUnhealthyErrorData {num_slots_behind}) => RpcResponseErrorData::NodeUnhealthy {num_slots_behind},
203 Err(_err) => {
204 RpcResponseErrorData::Empty
205 }
206 }
207 },
208 _ => RpcResponseErrorData::Empty
209 };
210
211 Err(RpcError::RpcResponseError {
212 code: rpc_error_object.code,
213 message: rpc_error_object.message,
214 data,
215 }
216 .into())
217 }
218 Err(err) => Err(RpcError::RpcRequestError(format!(
219 "Failed to deserialize RPC error response: {} [{}]",
220 serde_json::to_string(&json["error"]).unwrap(),
221 err
222 ))
223 .into()),
224 };
225 }
226 return Ok(json["result"].take());
227 }
228 }
229
230 fn url(&self) -> String {
231 self.url.clone()
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives one `getVersion` call to completion and discards the result.
    ///
    /// These tests pass as long as `send` finishes without panicking on the
    /// given Tokio runtime flavor. Presumably nothing listens on
    /// localhost:1234, so the call is expected to fail; verify if that changes.
    async fn send_get_version_and_ignore_result() {
        let sender = HttpSenderWithHeaders::new("http://localhost:1234".to_string(), None);
        let _ = sender
            .send(RpcRequest::GetVersion, serde_json::Value::Null)
            .await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn http_sender_on_tokio_multi_thread() {
        send_get_version_and_ignore_result().await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn http_sender_on_tokio_current_thread() {
        send_get_version_and_ignore_result().await;
    }
}