Skip to main content

statsig_rust/networking/providers/
net_provider_reqwest.rs

1use std::collections::HashMap;
2use std::io::{BufReader, Seek, SeekFrom, Write};
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use crate::log_d;
8use crate::{
9    log_e, log_w,
10    networking::{
11        http_types::{HttpMethod, RequestArgs, Response, ResponseData},
12        NetworkProvider,
13    },
14    StatsigErr,
15};
16
17use crate::networking::proxy_config::ProxyConfig;
18use reqwest::Method;
19
/// Tag passed as the first argument to the logging macros used in this file.
const TAG: &str = "NetworkProviderReqwest";

/// [`NetworkProvider`] implementation that performs HTTP requests with `reqwest`.
#[derive(Debug)]
pub struct NetworkProviderReqwest {
    /// Result of the temp-file probe made in `new()`. When `false`, `send`
    /// always buffers response bodies in memory instead of spooling them to a
    /// temporary file.
    has_file_write_access: bool,
}
25
26impl NetworkProviderReqwest {
27    pub fn new() -> Self {
28        Self {
29            has_file_write_access: tempfile::tempfile().is_ok(),
30        }
31    }
32}
33
34#[async_trait]
35impl NetworkProvider for NetworkProviderReqwest {
36    async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
37        if let Some(is_shutdown) = &args.is_shutdown {
38            if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
39                return Response {
40                    status_code: None,
41                    data: None,
42                    error: Some("Request was shutdown".to_string()),
43                };
44            }
45        }
46
47        let request = self.build_request(method, args);
48
49        let mut error = None;
50        let mut status_code = None;
51        let mut data = None;
52
53        match request.send().await {
54            Ok(response) => {
55                status_code = Some(response.status().as_u16());
56
57                let data_result =
58                    if !self.has_file_write_access || args.disable_file_streaming == Some(true) {
59                        Self::write_response_to_in_memory_buffer(response).await
60                    } else {
61                        Self::write_response_to_temp_file(response).await
62                    };
63
64                match data_result {
65                    Ok(response_data) => data = Some(response_data),
66                    Err(e) => {
67                        error = Some(e.to_string());
68                    }
69                }
70            }
71            Err(e) => {
72                let error_message = get_error_message(e);
73                error = Some(error_message);
74            }
75        }
76
77        Response {
78            status_code,
79            data,
80            error,
81        }
82    }
83}
84
85impl NetworkProviderReqwest {
86    fn build_request(
87        &self,
88        method: &HttpMethod,
89        request_args: &RequestArgs,
90    ) -> reqwest::RequestBuilder {
91        let method_actual = match method {
92            HttpMethod::GET => Method::GET,
93            HttpMethod::POST => Method::POST,
94        };
95        let is_post = method_actual == Method::POST;
96
97        let mut client_builder = reqwest::Client::builder();
98
99        // configure proxy if available
100        if let Some(proxy_config) = request_args.proxy_config.as_ref() {
101            client_builder = Self::configure_proxy(client_builder, proxy_config);
102        }
103
104        let client = client_builder.build().unwrap_or_else(|e| {
105            log_e!(TAG, "Failed to build reqwest client with proxy config: {}. Falling back to default client.", e);
106            reqwest::Client::new()
107        });
108
109        let mut request = client.request(method_actual, &request_args.url);
110
111        let timeout_duration = match request_args.timeout_ms > 0 {
112            true => Duration::from_millis(request_args.timeout_ms),
113            false => Duration::from_secs(10),
114        };
115        request = request.timeout(timeout_duration);
116
117        if let Some(headers) = &request_args.headers {
118            for (key, value) in headers {
119                request = request.header(key, value);
120            }
121        }
122
123        if let Some(params) = &request_args.query_params {
124            request = request.query(params);
125        }
126
127        if is_post {
128            let bytes = match &request_args.body {
129                Some(b) => b.clone(),
130                None => vec![],
131            };
132            let byte_len = bytes.len();
133
134            request = request.body(bytes);
135            request = request.header("Content-Length", byte_len.to_string());
136        }
137
138        request
139    }
140
141    fn configure_proxy(
142        client_builder: reqwest::ClientBuilder,
143        proxy_config: &ProxyConfig,
144    ) -> reqwest::ClientBuilder {
145        let (Some(host), Some(port)) = (&proxy_config.proxy_host, &proxy_config.proxy_port) else {
146            return client_builder;
147        };
148
149        let proxy_url = format!(
150            "{}://{}:{}",
151            proxy_config.proxy_protocol.as_deref().unwrap_or("http"),
152            host,
153            port
154        );
155
156        let Ok(proxy) = reqwest::Proxy::all(&proxy_url) else {
157            log_w!(TAG, "Failed to create proxy for URL: {}", proxy_url);
158            return client_builder;
159        };
160
161        let Some(auth) = &proxy_config.proxy_auth else {
162            return client_builder.proxy(proxy);
163        };
164
165        let Some((username, password)) = auth.split_once(':') else {
166            log_w!(
167                TAG,
168                "Invalid proxy auth format. Expected 'username:password'"
169            );
170            return client_builder.proxy(proxy);
171        };
172
173        client_builder.proxy(proxy.basic_auth(username, password))
174    }
175
176    async fn write_response_to_temp_file(
177        response: reqwest::Response,
178    ) -> Result<ResponseData, StatsigErr> {
179        let headers = get_response_headers(&response);
180        let mut response = response;
181        let mut temp_file = tempfile::spooled_tempfile(1024 * 1024 * 2); // 2MB
182
183        let mut total_bytes = 0;
184        while let Some(item) = response
185            .chunk()
186            .await
187            .map_err(|e| StatsigErr::FileError(e.to_string()))?
188        {
189            total_bytes += item.len();
190            temp_file
191                .write_all(&item)
192                .map_err(|e| StatsigErr::FileError(e.to_string()))?;
193        }
194
195        temp_file
196            .seek(SeekFrom::Start(0))
197            .map_err(|e| StatsigErr::FileError(e.to_string()))?;
198
199        let reader = BufReader::new(temp_file);
200
201        log_d!(TAG, "Wrote {} bytes to spooled temp file", total_bytes);
202
203        Ok(ResponseData::from_stream_with_headers(
204            Box::new(reader),
205            headers,
206        ))
207    }
208
209    async fn write_response_to_in_memory_buffer(
210        response: reqwest::Response,
211    ) -> Result<ResponseData, StatsigErr> {
212        let headers = get_response_headers(&response);
213        let bytes = response
214            .bytes()
215            .await
216            .map_err(|e| StatsigErr::SerializationError(e.to_string()))?;
217
218        log_d!(TAG, "Wrote {} bytes to in-memory buffer", bytes.len());
219
220        Ok(ResponseData::from_bytes_with_headers(
221            bytes.to_vec(),
222            headers,
223        ))
224    }
225}
226
227fn get_error_message(error: reqwest::Error) -> String {
228    let mut error_message = error.to_string();
229
230    if let Some(url_error) = error.url() {
231        error_message.push_str(&format!(". URL: {}", url_error));
232    }
233
234    if let Some(status_error) = error.status() {
235        error_message.push_str(&format!(". Status: {}", status_error));
236    }
237
238    error_message
239}
240
241fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
242    let headers = response.headers();
243    if headers.is_empty() {
244        return None;
245    }
246
247    let mut headers_map = HashMap::new();
248    for (key, value) in headers {
249        headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
250    }
251
252    Some(headers_map)
253}