statsig_rust/networking/providers/
net_provider_reqwest.rs1use std::collections::HashMap;
2use std::time::Duration;
3
4use async_trait::async_trait;
5
6use crate::{
7 log_w,
8 networking::{
9 http_types::{HttpMethod, RequestArgs, Response},
10 NetworkProvider,
11 },
12};
13
14use reqwest::Method;
15
/// Log tag attached to every warning emitted by this provider.
const TAG: &str = "NetworkProviderReqwest";

/// `NetworkProvider` implementation that performs HTTP requests with the `reqwest` crate.
///
/// The type is stateless; each request builds its own client (see `build_request`).
pub struct NetworkProviderReqwest {}
19
20#[async_trait]
21impl NetworkProvider for NetworkProviderReqwest {
22 async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
23 if let Some(is_shutdown) = &args.is_shutdown {
24 if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
25 return Response {
26 status_code: 0,
27 data: None,
28 error: Some("Request was shutdown".to_string()),
29 headers: None,
30 };
31 }
32 }
33
34 let request = self.build_request(method, args);
35
36 let error;
37 let mut status_code = 0;
38 let mut data = None;
39 let mut headers = None;
40
41 match request.send().await {
42 Ok(response) => {
43 status_code = response.status().as_u16();
44 headers = get_response_headers(&response);
45 data = response.bytes().await.ok().map(|bytes| bytes.to_vec());
46 error = None;
47 }
48 Err(e) => {
49 let error_message = get_error_message(e);
50 log_w!(TAG, "Request Error: {} {}", &args.url, error_message);
51 error = Some(error_message);
52 }
53 }
54
55 Response {
56 status_code,
57 data,
58 error,
59 headers,
60 }
61 }
62}
63
64impl NetworkProviderReqwest {
65 fn build_request(
66 &self,
67 method: &HttpMethod,
68 request_args: &RequestArgs,
69 ) -> reqwest::RequestBuilder {
70 let method_actual = match method {
71 HttpMethod::GET => Method::GET,
72 HttpMethod::POST => Method::POST,
73 };
74 let is_post = method_actual == Method::POST;
75
76 let client = reqwest::Client::new();
77 let mut request = client.request(method_actual, &request_args.url);
78
79 let timeout_duration = match request_args.timeout_ms > 0 {
80 true => Duration::from_millis(request_args.timeout_ms),
81 false => Duration::from_secs(10),
82 };
83 request = request.timeout(timeout_duration);
84
85 if let Some(headers) = &request_args.headers {
86 for (key, value) in headers {
87 request = request.header(key, value);
88 }
89 }
90
91 if let Some(params) = &request_args.query_params {
92 request = request.query(params);
93 }
94
95 if is_post {
96 let bytes = match &request_args.body {
97 Some(b) => b.clone(),
98 None => vec![],
99 };
100 let byte_len = bytes.len();
101
102 request = request.body(bytes);
103 request = request.header("Content-Length", byte_len.to_string());
104 }
105
106 request
107 }
108}
109
110fn get_error_message(error: reqwest::Error) -> String {
111 let mut error_message = error.to_string();
112
113 if let Some(url_error) = error.url() {
114 error_message.push_str(&format!(". URL: {}", url_error));
115 }
116
117 if let Some(status_error) = error.status() {
118 error_message.push_str(&format!(". Status: {}", status_error));
119 }
120
121 error_message
122}
123
124fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
125 let headers = response.headers();
126 if headers.is_empty() {
127 return None;
128 }
129
130 let mut headers_map = HashMap::new();
131 for (key, value) in headers {
132 headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
133 }
134
135 Some(headers_map)
136}