1#[cfg(feature = "reqwest")]
5mod reqwest;
6#[cfg(feature = "reqwest-11")]
7mod reqwest_11;
8mod shim;
9
/// HTTP transport wrapper that attaches Unleash identification headers to
/// requests built through its `get`/`post` methods.
///
/// Header names are held pre-built as `C::HeaderName` (they are constructed
/// once in `HTTP::new`), next to the values they carry, so attaching them to
/// a request only requires borrowing.
pub struct HTTP<C: HttpClient> {
    // Pre-built header names, each paired with one of the value fields below.
    authorization_header: C::HeaderName,
    app_name_header: C::HeaderName,
    unleash_app_name_header: C::HeaderName,
    unleash_sdk_header: C::HeaderName,
    unleash_connection_id_header: C::HeaderName,
    instance_id_header: C::HeaderName,
    // Sent under both the `appname` and `unleash-appname` headers.
    app_name: String,
    // Value of the `unleash-sdk` header (obtained from `get_sdk_version`).
    sdk_version: &'static str,
    instance_id: String,
    connection_id: String,
    // Value for the `authorization` header; the header is omitted when `None`.
    authorization: Option<String>,
    // Underlying HTTP client used to start GET/POST requests.
    client: C,
}
28
29use crate::version::get_sdk_version;
30use serde::{de::DeserializeOwned, Serialize};
31#[doc(inline)]
32pub use shim::HttpClient;
33
34impl<C> HTTP<C>
35where
36 C: HttpClient + Default,
37{
38 pub fn new(
40 app_name: String,
41 instance_id: String,
42 connection_id: String,
43 authorization: Option<String>,
44 ) -> Result<Self, C::Error> {
45 Ok(HTTP {
46 client: C::default(),
47 app_name,
48 sdk_version: get_sdk_version(),
49 connection_id,
50 instance_id,
51 authorization,
52 authorization_header: C::build_header("authorization")?,
53 app_name_header: C::build_header("appname")?,
54 unleash_app_name_header: C::build_header("unleash-appname")?,
55 unleash_sdk_header: C::build_header("unleash-sdk")?,
56 unleash_connection_id_header: C::build_header("unleash-connection-id")?,
57 instance_id_header: C::build_header("instance_id")?,
58 })
59 }
60
61 pub fn get(&self, uri: &str) -> C::RequestBuilder {
63 let request = self.client.get(uri);
64 self.attach_headers(request)
65 }
66
67 pub async fn get_json<T: DeserializeOwned>(
69 &self,
70 endpoint: &str,
71 interval: Option<u64>,
72 ) -> Result<T, C::Error> {
73 let mut request = self.get(endpoint);
74 if let Some(interval) = interval {
75 request = C::header(
76 request,
77 &C::build_header("unleash-interval")?,
78 &interval.to_string(),
79 );
80 }
81 C::get_json(request).await
82 }
83
84 pub fn post(&self, uri: &str) -> C::RequestBuilder {
86 let request = self.client.post(uri);
87 self.attach_headers(request)
88 }
89
90 pub async fn post_json<T: Serialize + Sync>(
93 &self,
94 endpoint: &str,
95 content: T,
96 interval: Option<u64>,
97 ) -> Result<bool, C::Error> {
98 let mut request = self.post(endpoint);
99 if let Some(interval) = interval {
100 request = C::header(
101 request,
102 &C::build_header("unleash-interval")?,
103 &interval.to_string(),
104 );
105 }
106 C::post_json(request, &content).await
107 }
108
109 fn attach_headers(&self, request: C::RequestBuilder) -> C::RequestBuilder {
110 let request = C::header(request, &self.app_name_header, self.app_name.as_str());
111 let request = C::header(
112 request,
113 &self.unleash_app_name_header,
114 self.app_name.as_str(),
115 );
116 let request = C::header(request, &self.unleash_sdk_header, self.sdk_version);
117 let request = C::header(
118 request,
119 &self.unleash_connection_id_header,
120 self.connection_id.as_str(),
121 );
122 let request = C::header(request, &self.instance_id_header, self.instance_id.as_str());
123 if let Some(auth) = &self.authorization {
124 C::header(request, &self.authorization_header.clone(), auth.as_str())
125 } else {
126 request
127 }
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use regex::Regex;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};
    use uuid::Uuid;

    const CONNECTION_ID: &str = "d512f8ec-d972-40a5-9a30-a0a6e85d93ac";

    /// Test double that records every header set on the requests it builds.
    ///
    /// `get`/`post` return clones that share the same `Arc`. As a result,
    /// headers set on a request are visible through the original client
    /// afterwards.
    #[derive(Clone, Default)]
    struct MockHttpClient {
        headers: Arc<RwLock<HashMap<String, String>>>,
    }

    #[async_trait]
    impl HttpClient for MockHttpClient {
        type Error = std::io::Error;
        type HeaderName = String;
        type RequestBuilder = Self;

        fn build_header(name: &'static str) -> Result<Self::HeaderName, Self::Error> {
            Ok(name.to_string())
        }

        fn header(builder: Self, key: &Self::HeaderName, value: &str) -> Self::RequestBuilder {
            // A poisoned lock means an earlier assertion already panicked;
            // fail loudly instead of silently dropping the header.
            builder
                .headers
                .write()
                .expect("mock header map lock poisoned")
                .insert(key.clone(), value.to_string());
            builder
        }

        fn get(&self, _uri: &str) -> Self::RequestBuilder {
            self.clone()
        }

        fn post(&self, _uri: &str) -> Self::RequestBuilder {
            self.clone()
        }

        async fn get_json<T: DeserializeOwned>(
            _req: Self::RequestBuilder,
        ) -> Result<T, Self::Error> {
            // Report deserialization failures through the declared error type
            // rather than panicking inside the mock.
            serde_json::from_value(json!({}))
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
        }

        async fn post_json<T: Serialize + Sync>(
            _req: Self::RequestBuilder,
            _content: &T,
        ) -> Result<bool, Self::Error> {
            Ok(true)
        }
    }

    /// Builds an `HTTP` wrapper over a fresh mock with fixed identity values.
    fn mock_http(authorization: Option<&str>) -> HTTP<MockHttpClient> {
        HTTP::<MockHttpClient>::new(
            "my_app".to_string(),
            "my_instance_id".to_string(),
            CONNECTION_ID.to_string(),
            authorization.map(String::from),
        )
        .expect("mock header names are always valid")
    }

    #[tokio::test]
    async fn test_specific_headers() {
        let http_client = mock_http(Some("auth_token"));

        http_client
            .get_json::<serde_json::Value>("http://example.com", Some(15))
            .await
            .expect("mock get_json should succeed");
        let headers = http_client.client.headers.read().unwrap();

        assert_eq!(headers.get("appname").unwrap(), "my_app");
        assert_eq!(headers.get("unleash-appname").unwrap(), "my_app");
        assert_eq!(headers.get("instance_id").unwrap(), "my_instance_id");
        assert_eq!(headers.get("unleash-connection-id").unwrap(), CONNECTION_ID);
        assert_eq!(headers.get("unleash-interval").unwrap(), "15");
        assert_eq!(headers.get("authorization").unwrap(), "auth_token");

        let version_regex = Regex::new(r"^unleash-client-rust:\d+\.\d+\.\d+$").unwrap();
        let sdk_version = headers.get("unleash-sdk").unwrap();
        assert!(
            version_regex.is_match(sdk_version),
            "Version output did not match expected format: {sdk_version}"
        );

        let connection_id = headers.get("unleash-connection-id").unwrap();
        assert!(
            Uuid::parse_str(connection_id).is_ok(),
            "Connection ID is not a valid UUID"
        );
    }

    #[tokio::test]
    async fn test_post_json_attaches_headers() {
        let http_client = mock_http(Some("auth_token"));

        let accepted = http_client
            .post_json("http://example.com", json!({"key": "value"}), Some(30))
            .await
            .expect("mock post_json should succeed");
        assert!(accepted);

        let headers = http_client.client.headers.read().unwrap();
        assert_eq!(headers.get("appname").unwrap(), "my_app");
        assert_eq!(headers.get("unleash-appname").unwrap(), "my_app");
        assert_eq!(headers.get("instance_id").unwrap(), "my_instance_id");
        assert_eq!(headers.get("unleash-connection-id").unwrap(), CONNECTION_ID);
        assert_eq!(headers.get("unleash-interval").unwrap(), "30");
        assert_eq!(headers.get("authorization").unwrap(), "auth_token");
    }

    #[tokio::test]
    async fn test_optional_headers_omitted_when_absent() {
        let http_client = mock_http(None);

        http_client
            .get_json::<serde_json::Value>("http://example.com", None)
            .await
            .expect("mock get_json should succeed");

        let headers = http_client.client.headers.read().unwrap();
        assert!(!headers.contains_key("authorization"));
        assert!(!headers.contains_key("unleash-interval"));
        assert_eq!(headers.get("unleash-appname").unwrap(), "my_app");
    }
}