shadow_rpc_auth/
http_sender.rs

//! Copied from the `solana-rpc-client` crate's `HttpSender` and modified into
//! [`HttpSenderWithHeaders`], which accepts default headers to send with every
//! request. This is useful for passing auth headers to RPC services like GenesysGo.
4use reqwest::header::HeaderMap;
5use serde::Deserialize;
6use serde_json::{json, Value};
7use solana_client::rpc_custom_error as custom_error;
8use solana_client::rpc_request::{RpcError, RpcRequest, RpcResponseErrorData};
9use solana_client::rpc_response::RpcSimulateTransactionResult;
10use solana_client::rpc_sender::{RpcSender, RpcTransportStats};
11use {
12    async_trait::async_trait,
13    log::*,
14    reqwest::{
15        self,
16        header::{self, CONTENT_TYPE, RETRY_AFTER},
17        StatusCode,
18    },
19    std::{
20        sync::{
21            atomic::{AtomicU64, Ordering},
22            Arc, RwLock,
23        },
24        time::{Duration, Instant},
25    },
26    tokio::time::sleep,
27};
28
29#[derive(Deserialize, Debug)]
30pub struct RpcErrorObject {
31    pub code: i64,
32    pub message: String,
33}
34
35pub struct HttpSenderWithHeaders {
36    client: Arc<reqwest::Client>,
37    url: String,
38    request_id: AtomicU64,
39    stats: RwLock<RpcTransportStats>,
40}
41
42/// Nonblocking [`RpcSender`] over HTTP.
43impl HttpSenderWithHeaders {
44    /// Create an HTTP RPC sender.
45    ///
46    /// The URL is an HTTP URL, usually for port 8899, as in
47    /// "http://localhost:8899". The sender has a default timeout of 30 seconds.
48    pub fn new<U: ToString>(url: U, headers: Option<HeaderMap>) -> Self {
49        Self::new_with_timeout(url, Duration::from_secs(30), headers)
50    }
51
52    /// Create an HTTP RPC sender.
53    ///
54    /// The URL is an HTTP URL, usually for port 8899.
55    pub fn new_with_timeout<U: ToString>(
56        url: U,
57        timeout: Duration,
58        headers: Option<HeaderMap>,
59    ) -> Self {
60        let mut default_headers = header::HeaderMap::new();
61        default_headers.append(
62            header::HeaderName::from_static("solana-client"),
63            header::HeaderValue::from_str(
64                format!("rust/{}", solana_version::Version::default()).as_str(),
65            )
66            .unwrap(),
67        );
68        if let Some(headers) = headers {
69            default_headers.extend(headers);
70        }
71
72        let client = Arc::new(
73            reqwest::Client::builder()
74                .default_headers(default_headers)
75                .timeout(timeout)
76                .pool_idle_timeout(timeout)
77                .build()
78                .expect("build rpc client"),
79        );
80
81        Self {
82            client,
83            url: url.to_string(),
84            request_id: AtomicU64::new(0),
85            stats: RwLock::new(RpcTransportStats::default()),
86        }
87    }
88}
89
90struct StatsUpdater<'a> {
91    stats: &'a RwLock<RpcTransportStats>,
92    request_start_time: Instant,
93    rate_limited_time: Duration,
94}
95
96impl<'a> StatsUpdater<'a> {
97    fn new(stats: &'a RwLock<RpcTransportStats>) -> Self {
98        Self {
99            stats,
100            request_start_time: Instant::now(),
101            rate_limited_time: Duration::default(),
102        }
103    }
104
105    fn add_rate_limited_time(&mut self, duration: Duration) {
106        self.rate_limited_time += duration;
107    }
108}
109
110impl<'a> Drop for StatsUpdater<'a> {
111    fn drop(&mut self) {
112        let mut stats = self.stats.write().unwrap();
113        stats.request_count += 1;
114        stats.elapsed_time += Instant::now().duration_since(self.request_start_time);
115        stats.rate_limited_time += self.rate_limited_time;
116    }
117}
118
119pub fn build_request_json(req: &RpcRequest, id: u64, params: Value) -> Value {
120    let jsonrpc = "2.0";
121    json!({
122       "jsonrpc": jsonrpc,
123       "id": id,
124       "method": format!("{}", req),
125       "params": params,
126    })
127}
128
129#[async_trait]
130impl RpcSender for HttpSenderWithHeaders {
131    fn get_transport_stats(&self) -> RpcTransportStats {
132        self.stats.read().unwrap().clone()
133    }
134
135    async fn send(
136        &self,
137        request: RpcRequest,
138        params: serde_json::Value,
139    ) -> solana_client::client_error::Result<serde_json::Value> {
140        let mut stats_updater = StatsUpdater::new(&self.stats);
141
142        let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
143        let request_json = build_request_json(&request, request_id, params).to_string();
144
145        let mut too_many_requests_retries = 5;
146        loop {
147            let response = {
148                let client = self.client.clone();
149                let request_json = request_json.clone();
150                client
151                    .post(&self.url)
152                    .header(CONTENT_TYPE, "application/json")
153                    .body(request_json)
154                    .send()
155                    .await
156            }?;
157
158            if !response.status().is_success() {
159                if response.status() == StatusCode::TOO_MANY_REQUESTS
160                    && too_many_requests_retries > 0
161                {
162                    let mut duration = Duration::from_millis(500);
163                    if let Some(retry_after) = response.headers().get(RETRY_AFTER) {
164                        if let Ok(retry_after) = retry_after.to_str() {
165                            if let Ok(retry_after) = retry_after.parse::<u64>() {
166                                if retry_after < 120 {
167                                    duration = Duration::from_secs(retry_after);
168                                }
169                            }
170                        }
171                    }
172
173                    too_many_requests_retries -= 1;
174                    debug!(
175                                "Too many requests: server responded with {:?}, {} retries left, pausing for {:?}",
176                                response, too_many_requests_retries, duration
177                            );
178
179                    sleep(duration).await;
180                    stats_updater.add_rate_limited_time(duration);
181                    continue;
182                }
183                return Err(response.error_for_status().unwrap_err().into());
184            }
185
186            let mut json = response.json::<serde_json::Value>().await?;
187            if json["error"].is_object() {
188                return match serde_json::from_value::<RpcErrorObject>(json["error"].clone()) {
189                    Ok(rpc_error_object) => {
190                        let data = match rpc_error_object.code {
191                                    custom_error::JSON_RPC_SERVER_ERROR_SEND_TRANSACTION_PREFLIGHT_FAILURE => {
192                                        match serde_json::from_value::<RpcSimulateTransactionResult>(json["error"]["data"].clone()) {
193                                            Ok(data) => RpcResponseErrorData::SendTransactionPreflightFailure(data),
194                                            Err(err) => {
195                                                debug!("Failed to deserialize RpcSimulateTransactionResult: {:?}", err);
196                                                RpcResponseErrorData::Empty
197                                            }
198                                        }
199                                    },
200                                    custom_error::JSON_RPC_SERVER_ERROR_NODE_UNHEALTHY => {
201                                        match serde_json::from_value::<custom_error::NodeUnhealthyErrorData>(json["error"]["data"].clone()) {
202                                            Ok(custom_error::NodeUnhealthyErrorData {num_slots_behind}) => RpcResponseErrorData::NodeUnhealthy {num_slots_behind},
203                                            Err(_err) => {
204                                                RpcResponseErrorData::Empty
205                                            }
206                                        }
207                                    },
208                                    _ => RpcResponseErrorData::Empty
209                                };
210
211                        Err(RpcError::RpcResponseError {
212                            code: rpc_error_object.code,
213                            message: rpc_error_object.message,
214                            data,
215                        }
216                        .into())
217                    }
218                    Err(err) => Err(RpcError::RpcRequestError(format!(
219                        "Failed to deserialize RPC error response: {} [{}]",
220                        serde_json::to_string(&json["error"]).unwrap(),
221                        err
222                    ))
223                    .into()),
224                };
225            }
226            return Ok(json["result"].take());
227        }
228    }
229
230    fn url(&self) -> String {
231        self.url.clone()
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends a `getVersion` request to `http://localhost:1234` and discards the
    /// result. The tests only check that `send` completes without panicking on
    /// each Tokio runtime flavor; whether anything is listening there is not
    /// asserted.
    async fn send_get_version_to_local_port() {
        let sender = HttpSenderWithHeaders::new("http://localhost:1234", None);
        let _ = sender
            .send(RpcRequest::GetVersion, serde_json::Value::Null)
            .await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn http_sender_on_tokio_multi_thread() {
        send_get_version_to_local_port().await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn http_sender_on_tokio_current_thread() {
        send_get_version_to_local_port().await;
    }
}