1use reqwest::header::{
2 HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE, USER_AGENT,
3};
4use reqwest::{Client, Method, StatusCode, Url};
5use serde::Serialize;
6use serde_json::Value;
7use std::borrow::Cow;
8use thiserror::Error;
9
10#[derive(Debug, Clone)]
11pub enum AuthStrategy {
12 None,
13 Bearer(String),
14 Header {
15 name: HeaderName,
16 value: HeaderValue,
17 },
18}
19
20#[derive(Debug, Clone)]
21pub struct RequestFactory {
22 client: Client,
23 base_url: Url,
24 auth: AuthStrategy,
25 default_headers: HeaderMap,
26}
27
/// Raw response body returned by the byte-oriented request helpers.
#[derive(Clone, Debug)]
pub struct ResponseBytes {
    /// The `Content-Type` response header, when present and representable as a string.
    pub content_type: Option<String>,
    /// The response body, exactly as received.
    pub body: Vec<u8>,
}
33
34#[derive(Debug, Error)]
35pub enum HttpError {
36 #[error("{message}")]
37 Request {
38 message: String,
39 status: Option<StatusCode>,
40 body: Option<String>,
41 },
42 #[error("failed to build request: {0}")]
43 Build(String),
44 #[error("failed to parse response JSON: {0}")]
45 Decode(String),
46}
47
48impl HttpError {
49 pub fn request(
50 message: impl Into<String>,
51 status: Option<StatusCode>,
52 body: Option<String>,
53 ) -> Self {
54 Self::Request {
55 message: message.into(),
56 status,
57 body,
58 }
59 }
60}
61
62impl RequestFactory {
63 pub fn new(base_url: impl AsRef<str>) -> Result<Self, HttpError> {
64 let client = Client::builder()
65 .user_agent("xbp")
66 .build()
67 .map_err(|error| HttpError::Build(error.to_string()))?;
68 let base_url =
69 Url::parse(base_url.as_ref()).map_err(|error| HttpError::Build(error.to_string()))?;
70 let mut default_headers = HeaderMap::new();
71 default_headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
72 default_headers.insert(USER_AGENT, HeaderValue::from_static("xbp"));
73 Ok(Self {
74 client,
75 base_url,
76 auth: AuthStrategy::None,
77 default_headers,
78 })
79 }
80
81 pub fn with_auth(mut self, auth: AuthStrategy) -> Self {
82 self.auth = auth;
83 self
84 }
85
86 pub fn with_default_header(mut self, name: HeaderName, value: HeaderValue) -> Self {
87 self.default_headers.insert(name, value);
88 self
89 }
90
91 pub async fn get_json<T, Q>(&self, path: &str, query: Option<&Q>) -> Result<T, HttpError>
92 where
93 T: serde::de::DeserializeOwned,
94 Q: Serialize + ?Sized,
95 {
96 self.send_json(Method::GET, path, query, Option::<&Value>::None)
97 .await
98 }
99
100 pub async fn delete_json<T, Q>(&self, path: &str, query: Option<&Q>) -> Result<T, HttpError>
101 where
102 T: serde::de::DeserializeOwned,
103 Q: Serialize + ?Sized,
104 {
105 self.send_json(Method::DELETE, path, query, Option::<&Value>::None)
106 .await
107 }
108
109 pub async fn delete_json_with_body<T, Q, B>(
110 &self,
111 path: &str,
112 query: Option<&Q>,
113 body: &B,
114 ) -> Result<T, HttpError>
115 where
116 T: serde::de::DeserializeOwned,
117 Q: Serialize + ?Sized,
118 B: Serialize + ?Sized,
119 {
120 self.send_json(Method::DELETE, path, query, Some(body))
121 .await
122 }
123
124 pub async fn post_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
125 where
126 T: serde::de::DeserializeOwned,
127 B: Serialize + ?Sized,
128 {
129 self.send_json(Method::POST, path, Option::<&Value>::None, Some(body))
130 .await
131 }
132
133 pub async fn put_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
134 where
135 T: serde::de::DeserializeOwned,
136 B: Serialize + ?Sized,
137 {
138 self.send_json(Method::PUT, path, Option::<&Value>::None, Some(body))
139 .await
140 }
141
142 pub async fn patch_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
143 where
144 T: serde::de::DeserializeOwned,
145 B: Serialize + ?Sized,
146 {
147 self.send_json(Method::PATCH, path, Option::<&Value>::None, Some(body))
148 .await
149 }
150
151 pub async fn post_bytes(
152 &self,
153 path: &str,
154 bytes: Vec<u8>,
155 content_type: &'static str,
156 ) -> Result<ResponseBytes, HttpError> {
157 let response = self
158 .request(Method::POST, path)?
159 .header(CONTENT_TYPE, content_type)
160 .body(bytes)
161 .send()
162 .await
163 .map_err(map_send_error)?;
164 self.read_bytes_response(response).await
165 }
166
167 pub async fn get_bytes<Q>(
168 &self,
169 path: &str,
170 query: Option<&Q>,
171 ) -> Result<ResponseBytes, HttpError>
172 where
173 Q: Serialize + ?Sized,
174 {
175 let mut request = self.request(Method::GET, path)?;
176 if let Some(query) = query {
177 request = request.query(query);
178 }
179 let response = request.send().await.map_err(map_send_error)?;
180 self.read_bytes_response(response).await
181 }
182
183 async fn send_json<T, Q, B>(
184 &self,
185 method: Method,
186 path: &str,
187 query: Option<&Q>,
188 body: Option<&B>,
189 ) -> Result<T, HttpError>
190 where
191 T: serde::de::DeserializeOwned,
192 Q: Serialize + ?Sized,
193 B: Serialize + ?Sized,
194 {
195 let mut request = self.request(method, path)?;
196 if let Some(query) = query {
197 request = request.query(query);
198 }
199 if let Some(body) = body {
200 request = request.json(body);
201 }
202 let response = request.send().await.map_err(map_send_error)?;
203 let status = response.status();
204 let body = response
205 .text()
206 .await
207 .map_err(|error| HttpError::request(error.to_string(), Some(status), None))?;
208 if !status.is_success() {
209 let message = extract_cloudflare_error_message(&body)
210 .or_else(|| extract_github_error_message(&body))
211 .unwrap_or_else(|| format!("HTTP {}", status));
212 return Err(HttpError::request(message, Some(status), Some(body)));
213 }
214 serde_json::from_str(&body).map_err(|error| HttpError::Decode(error.to_string()))
215 }
216
217 fn request(&self, method: Method, path: &str) -> Result<reqwest::RequestBuilder, HttpError> {
218 let mut url = self
219 .base_url
220 .join(path)
221 .map_err(|error| HttpError::Build(error.to_string()))?;
222 if path.starts_with('/') {
223 let joined = format!("{}{}", self.base_url.as_str().trim_end_matches('/'), path);
224 url = Url::parse(&joined).map_err(|error| HttpError::Build(error.to_string()))?;
225 }
226
227 let mut builder = self.client.request(method, url);
228 builder = builder.headers(self.default_headers.clone());
229 match &self.auth {
230 AuthStrategy::None => {}
231 AuthStrategy::Bearer(token) => {
232 builder = builder.header(AUTHORIZATION, format!("Bearer {}", token));
233 }
234 AuthStrategy::Header { name, value } => {
235 builder = builder.header(name, value);
236 }
237 }
238 Ok(builder)
239 }
240
241 async fn read_bytes_response(
242 &self,
243 response: reqwest::Response,
244 ) -> Result<ResponseBytes, HttpError> {
245 let status = response.status();
246 let content_type = response
247 .headers()
248 .get(CONTENT_TYPE)
249 .and_then(|value| value.to_str().ok())
250 .map(str::to_string);
251 let bytes = response
252 .bytes()
253 .await
254 .map_err(|error| HttpError::request(error.to_string(), Some(status), None))?;
255 if !status.is_success() {
256 let body = String::from_utf8_lossy(&bytes).to_string();
257 let message = extract_cloudflare_error_message(&body)
258 .or_else(|| extract_github_error_message(&body))
259 .unwrap_or_else(|| format!("HTTP {}", status));
260 return Err(HttpError::request(message, Some(status), Some(body)));
261 }
262 Ok(ResponseBytes {
263 content_type,
264 body: bytes.to_vec(),
265 })
266 }
267}
268
269fn map_send_error(error: reqwest::Error) -> HttpError {
270 if error.is_builder() {
271 return HttpError::Build(format!(
272 "{}. Check URL construction and auth/header values for stray whitespace or invalid characters.",
273 error
274 ));
275 }
276
277 HttpError::request(error.to_string(), None, None)
278}
279
280pub fn extract_github_error_message(body: &str) -> Option<String> {
281 let parsed = serde_json::from_str::<Value>(body.trim()).ok()?;
282 parsed
283 .get("message")
284 .and_then(Value::as_str)
285 .map(str::trim)
286 .filter(|value| !value.is_empty())
287 .map(ToOwned::to_owned)
288}
289
290pub fn extract_cloudflare_error_message(body: &str) -> Option<String> {
291 let parsed = serde_json::from_str::<Value>(body.trim()).ok()?;
292 let errors = parsed.get("errors")?.as_array()?;
293 let messages = errors
294 .iter()
295 .filter_map(|entry| {
296 let code = entry.get("code").and_then(Value::as_i64);
297 let message = entry.get("message").and_then(Value::as_str)?.trim();
298 if message.is_empty() {
299 return None;
300 }
301 Some(match code {
302 Some(code) => Cow::Owned(format!("{} ({})", message, code)),
303 None => Cow::Borrowed(message),
304 })
305 })
306 .collect::<Vec<_>>();
307 if messages.is_empty() {
308 None
309 } else {
310 Some(
311 messages
312 .into_iter()
313 .map(|value| value.into_owned())
314 .collect::<Vec<_>>()
315 .join("; "),
316 )
317 }
318}