statsig_rust/networking/providers/net_provider_reqwest.rs

1use std::collections::HashMap;
2use std::io::{BufReader, Seek, SeekFrom, Write};
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use crate::{
8    log_e, log_w,
9    networking::{
10        http_types::{HttpMethod, RequestArgs, Response, ResponseData},
11        NetworkProvider,
12    },
13    StatsigErr,
14};
15
16use crate::networking::proxy_config::ProxyConfig;
17use reqwest::Method;
18
19const TAG: &str = "NetworkProviderReqwest";
20
21pub struct NetworkProviderReqwest {
22    has_file_write_access: bool,
23}
24
25impl NetworkProviderReqwest {
26    pub fn new() -> Self {
27        let has_file_write_access = match tempfile::tempfile() {
28            Ok(_) => true,
29            Err(_) => false,
30        };
31
32        Self {
33            has_file_write_access,
34        }
35    }
36}
37
38#[async_trait]
39impl NetworkProvider for NetworkProviderReqwest {
40    async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
41        if let Some(is_shutdown) = &args.is_shutdown {
42            if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
43                return Response {
44                    status_code: None,
45                    data: None,
46                    error: Some("Request was shutdown".to_string()),
47                    headers: None,
48                };
49            }
50        }
51
52        let request = self.build_request(method, args);
53
54        let mut error = None;
55        let mut status_code = None;
56        let mut data = None;
57        let mut headers = None;
58
59        match request.send().await {
60            Ok(response) => {
61                status_code = Some(response.status().as_u16());
62                headers = get_response_headers(&response);
63
64                let data_result =
65                    if !self.has_file_write_access || args.disable_file_streaming == Some(true) {
66                        Self::write_response_to_in_memory_buffer(response).await
67                    } else {
68                        Self::write_response_to_temp_file(response).await
69                    };
70
71                match data_result {
72                    Ok(response_data) => data = Some(response_data),
73                    Err(e) => {
74                        error = Some(e.to_string());
75                    }
76                }
77            }
78            Err(e) => {
79                let error_message = get_error_message(e);
80                error = Some(error_message);
81            }
82        }
83
84        Response {
85            status_code,
86            data,
87            error,
88            headers,
89        }
90    }
91}
92
93impl NetworkProviderReqwest {
94    fn build_request(
95        &self,
96        method: &HttpMethod,
97        request_args: &RequestArgs,
98    ) -> reqwest::RequestBuilder {
99        let method_actual = match method {
100            HttpMethod::GET => Method::GET,
101            HttpMethod::POST => Method::POST,
102        };
103        let is_post = method_actual == Method::POST;
104
105        let mut client_builder = reqwest::Client::builder();
106
107        // configure proxy if available
108        if let Some(proxy_config) = request_args.proxy_config.as_ref() {
109            client_builder = Self::configure_proxy(client_builder, proxy_config);
110        }
111
112        let client = client_builder.build().unwrap_or_else(|e| {
113            log_e!(TAG, "Failed to build reqwest client with proxy config: {}. Falling back to default client.", e);
114            reqwest::Client::new()
115        });
116
117        let mut request = client.request(method_actual, &request_args.url);
118
119        let timeout_duration = match request_args.timeout_ms > 0 {
120            true => Duration::from_millis(request_args.timeout_ms),
121            false => Duration::from_secs(10),
122        };
123        request = request.timeout(timeout_duration);
124
125        if let Some(headers) = &request_args.headers {
126            for (key, value) in headers {
127                request = request.header(key, value);
128            }
129        }
130
131        if let Some(params) = &request_args.query_params {
132            request = request.query(params);
133        }
134
135        if is_post {
136            let bytes = match &request_args.body {
137                Some(b) => b.clone(),
138                None => vec![],
139            };
140            let byte_len = bytes.len();
141
142            request = request.body(bytes);
143            request = request.header("Content-Length", byte_len.to_string());
144        }
145
146        request
147    }
148
149    fn configure_proxy(
150        client_builder: reqwest::ClientBuilder,
151        proxy_config: &ProxyConfig,
152    ) -> reqwest::ClientBuilder {
153        let (Some(host), Some(port)) = (&proxy_config.proxy_host, &proxy_config.proxy_port) else {
154            return client_builder;
155        };
156
157        let proxy_url = format!(
158            "{}://{}:{}",
159            proxy_config.proxy_protocol.as_deref().unwrap_or("http"),
160            host,
161            port
162        );
163
164        let Ok(proxy) = reqwest::Proxy::all(&proxy_url) else {
165            log_w!(TAG, "Failed to create proxy for URL: {}", proxy_url);
166            return client_builder;
167        };
168
169        let Some(auth) = &proxy_config.proxy_auth else {
170            return client_builder.proxy(proxy);
171        };
172
173        let Some((username, password)) = auth.split_once(':') else {
174            log_w!(
175                TAG,
176                "Invalid proxy auth format. Expected 'username:password'"
177            );
178            return client_builder.proxy(proxy);
179        };
180
181        client_builder.proxy(proxy.basic_auth(username, password))
182    }
183
184    async fn write_response_to_temp_file(
185        response: reqwest::Response,
186    ) -> Result<ResponseData, StatsigErr> {
187        let mut response = response;
188        let mut temp_file = tempfile::spooled_tempfile(1024 * 1024 * 2); // 2MB
189
190        while let Some(item) = response
191            .chunk()
192            .await
193            .map_err(|e| StatsigErr::FileError(e.to_string()))?
194        {
195            temp_file
196                .write_all(&item)
197                .map_err(|e| StatsigErr::FileError(e.to_string()))?;
198        }
199
200        temp_file
201            .seek(SeekFrom::Start(0))
202            .map_err(|e| StatsigErr::FileError(e.to_string()))?;
203
204        let reader = BufReader::new(temp_file);
205        Ok(ResponseData::from_stream(Box::new(reader)))
206    }
207
208    async fn write_response_to_in_memory_buffer(
209        response: reqwest::Response,
210    ) -> Result<ResponseData, StatsigErr> {
211        let bytes = response
212            .bytes()
213            .await
214            .map_err(|e| StatsigErr::SerializationError(e.to_string()))?;
215
216        Ok(ResponseData::from_bytes(bytes.to_vec()))
217    }
218}
219
220fn get_error_message(error: reqwest::Error) -> String {
221    let mut error_message = error.to_string();
222
223    if let Some(url_error) = error.url() {
224        error_message.push_str(&format!(". URL: {}", url_error));
225    }
226
227    if let Some(status_error) = error.status() {
228        error_message.push_str(&format!(". Status: {}", status_error));
229    }
230
231    error_message
232}
233
234fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
235    let headers = response.headers();
236    if headers.is_empty() {
237        return None;
238    }
239
240    let mut headers_map = HashMap::new();
241    for (key, value) in headers {
242        headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
243    }
244
245    Some(headers_map)
246}