Skip to main content

xbp_cli/provider_support/
http.rs

1use reqwest::header::{
2    HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE, USER_AGENT,
3};
4use reqwest::{Client, Method, StatusCode, Url};
5use serde::Serialize;
6use serde_json::Value;
7use std::borrow::Cow;
8use thiserror::Error;
9
10#[derive(Debug, Clone)]
11pub enum AuthStrategy {
12    None,
13    Bearer(String),
14    Header {
15        name: HeaderName,
16        value: HeaderValue,
17    },
18}
19
20#[derive(Debug, Clone)]
21pub struct RequestFactory {
22    client: Client,
23    base_url: Url,
24    auth: AuthStrategy,
25    default_headers: HeaderMap,
26}
27
/// Raw body of a successful response together with its declared content type.
#[derive(Debug, Clone)]
pub struct ResponseBytes {
    /// Value of the response's `Content-Type` header, when present and
    /// representable as a string.
    pub content_type: Option<String>,
    /// Response body bytes.
    pub body: Vec<u8>,
}
33
34#[derive(Debug, Error)]
35pub enum HttpError {
36    #[error("{message}")]
37    Request {
38        message: String,
39        status: Option<StatusCode>,
40        body: Option<String>,
41    },
42    #[error("failed to build request: {0}")]
43    Build(String),
44    #[error("failed to parse response JSON: {0}")]
45    Decode(String),
46}
47
48impl HttpError {
49    pub fn request(
50        message: impl Into<String>,
51        status: Option<StatusCode>,
52        body: Option<String>,
53    ) -> Self {
54        Self::Request {
55            message: message.into(),
56            status,
57            body,
58        }
59    }
60}
61
62impl RequestFactory {
63    pub fn new(base_url: impl AsRef<str>) -> Result<Self, HttpError> {
64        let client = Client::builder()
65            .user_agent("xbp")
66            .build()
67            .map_err(|error| HttpError::Build(error.to_string()))?;
68        let base_url =
69            Url::parse(base_url.as_ref()).map_err(|error| HttpError::Build(error.to_string()))?;
70        let mut default_headers = HeaderMap::new();
71        default_headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
72        default_headers.insert(USER_AGENT, HeaderValue::from_static("xbp"));
73        Ok(Self {
74            client,
75            base_url,
76            auth: AuthStrategy::None,
77            default_headers,
78        })
79    }
80
81    pub fn with_auth(mut self, auth: AuthStrategy) -> Self {
82        self.auth = auth;
83        self
84    }
85
86    pub fn with_default_header(mut self, name: HeaderName, value: HeaderValue) -> Self {
87        self.default_headers.insert(name, value);
88        self
89    }
90
91    pub async fn get_json<T, Q>(&self, path: &str, query: Option<&Q>) -> Result<T, HttpError>
92    where
93        T: serde::de::DeserializeOwned,
94        Q: Serialize + ?Sized,
95    {
96        self.send_json(Method::GET, path, query, Option::<&Value>::None)
97            .await
98    }
99
100    pub async fn delete_json<T, Q>(&self, path: &str, query: Option<&Q>) -> Result<T, HttpError>
101    where
102        T: serde::de::DeserializeOwned,
103        Q: Serialize + ?Sized,
104    {
105        self.send_json(Method::DELETE, path, query, Option::<&Value>::None)
106            .await
107    }
108
109    pub async fn delete_json_with_body<T, Q, B>(
110        &self,
111        path: &str,
112        query: Option<&Q>,
113        body: &B,
114    ) -> Result<T, HttpError>
115    where
116        T: serde::de::DeserializeOwned,
117        Q: Serialize + ?Sized,
118        B: Serialize + ?Sized,
119    {
120        self.send_json(Method::DELETE, path, query, Some(body))
121            .await
122    }
123
124    pub async fn post_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
125    where
126        T: serde::de::DeserializeOwned,
127        B: Serialize + ?Sized,
128    {
129        self.send_json(Method::POST, path, Option::<&Value>::None, Some(body))
130            .await
131    }
132
133    pub async fn put_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
134    where
135        T: serde::de::DeserializeOwned,
136        B: Serialize + ?Sized,
137    {
138        self.send_json(Method::PUT, path, Option::<&Value>::None, Some(body))
139            .await
140    }
141
142    pub async fn patch_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
143    where
144        T: serde::de::DeserializeOwned,
145        B: Serialize + ?Sized,
146    {
147        self.send_json(Method::PATCH, path, Option::<&Value>::None, Some(body))
148            .await
149    }
150
151    pub async fn post_bytes(
152        &self,
153        path: &str,
154        bytes: Vec<u8>,
155        content_type: &'static str,
156    ) -> Result<ResponseBytes, HttpError> {
157        let response = self
158            .request(Method::POST, path)?
159            .header(CONTENT_TYPE, content_type)
160            .body(bytes)
161            .send()
162            .await
163            .map_err(map_send_error)?;
164        self.read_bytes_response(response).await
165    }
166
167    pub async fn get_bytes<Q>(
168        &self,
169        path: &str,
170        query: Option<&Q>,
171    ) -> Result<ResponseBytes, HttpError>
172    where
173        Q: Serialize + ?Sized,
174    {
175        let mut request = self.request(Method::GET, path)?;
176        if let Some(query) = query {
177            request = request.query(query);
178        }
179        let response = request.send().await.map_err(map_send_error)?;
180        self.read_bytes_response(response).await
181    }
182
183    async fn send_json<T, Q, B>(
184        &self,
185        method: Method,
186        path: &str,
187        query: Option<&Q>,
188        body: Option<&B>,
189    ) -> Result<T, HttpError>
190    where
191        T: serde::de::DeserializeOwned,
192        Q: Serialize + ?Sized,
193        B: Serialize + ?Sized,
194    {
195        let mut request = self.request(method, path)?;
196        if let Some(query) = query {
197            request = request.query(query);
198        }
199        if let Some(body) = body {
200            request = request.json(body);
201        }
202        let response = request.send().await.map_err(map_send_error)?;
203        let status = response.status();
204        let body = response
205            .text()
206            .await
207            .map_err(|error| HttpError::request(error.to_string(), Some(status), None))?;
208        if !status.is_success() {
209            let message = extract_cloudflare_error_message(&body)
210                .or_else(|| extract_github_error_message(&body))
211                .unwrap_or_else(|| format!("HTTP {}", status));
212            return Err(HttpError::request(message, Some(status), Some(body)));
213        }
214        serde_json::from_str(&body).map_err(|error| HttpError::Decode(error.to_string()))
215    }
216
217    fn request(&self, method: Method, path: &str) -> Result<reqwest::RequestBuilder, HttpError> {
218        let mut url = self
219            .base_url
220            .join(path)
221            .map_err(|error| HttpError::Build(error.to_string()))?;
222        if path.starts_with('/') {
223            let joined = format!("{}{}", self.base_url.as_str().trim_end_matches('/'), path);
224            url = Url::parse(&joined).map_err(|error| HttpError::Build(error.to_string()))?;
225        }
226
227        let mut builder = self.client.request(method, url);
228        builder = builder.headers(self.default_headers.clone());
229        match &self.auth {
230            AuthStrategy::None => {}
231            AuthStrategy::Bearer(token) => {
232                builder = builder.header(AUTHORIZATION, format!("Bearer {}", token));
233            }
234            AuthStrategy::Header { name, value } => {
235                builder = builder.header(name, value);
236            }
237        }
238        Ok(builder)
239    }
240
241    async fn read_bytes_response(
242        &self,
243        response: reqwest::Response,
244    ) -> Result<ResponseBytes, HttpError> {
245        let status = response.status();
246        let content_type = response
247            .headers()
248            .get(CONTENT_TYPE)
249            .and_then(|value| value.to_str().ok())
250            .map(str::to_string);
251        let bytes = response
252            .bytes()
253            .await
254            .map_err(|error| HttpError::request(error.to_string(), Some(status), None))?;
255        if !status.is_success() {
256            let body = String::from_utf8_lossy(&bytes).to_string();
257            let message = extract_cloudflare_error_message(&body)
258                .or_else(|| extract_github_error_message(&body))
259                .unwrap_or_else(|| format!("HTTP {}", status));
260            return Err(HttpError::request(message, Some(status), Some(body)));
261        }
262        Ok(ResponseBytes {
263            content_type,
264            body: bytes.to_vec(),
265        })
266    }
267}
268
269fn map_send_error(error: reqwest::Error) -> HttpError {
270    if error.is_builder() {
271        return HttpError::Build(format!(
272            "{}. Check URL construction and auth/header values for stray whitespace or invalid characters.",
273            error
274        ));
275    }
276
277    HttpError::request(error.to_string(), None, None)
278}
279
280pub fn extract_github_error_message(body: &str) -> Option<String> {
281    let parsed = serde_json::from_str::<Value>(body.trim()).ok()?;
282    parsed
283        .get("message")
284        .and_then(Value::as_str)
285        .map(str::trim)
286        .filter(|value| !value.is_empty())
287        .map(ToOwned::to_owned)
288}
289
290pub fn extract_cloudflare_error_message(body: &str) -> Option<String> {
291    let parsed = serde_json::from_str::<Value>(body.trim()).ok()?;
292    let errors = parsed.get("errors")?.as_array()?;
293    let messages = errors
294        .iter()
295        .filter_map(|entry| {
296            let code = entry.get("code").and_then(Value::as_i64);
297            let message = entry.get("message").and_then(Value::as_str)?.trim();
298            if message.is_empty() {
299                return None;
300            }
301            Some(match code {
302                Some(code) => Cow::Owned(format!("{} ({})", message, code)),
303                None => Cow::Borrowed(message),
304            })
305        })
306        .collect::<Vec<_>>();
307    if messages.is_empty() {
308        None
309    } else {
310        Some(
311            messages
312                .into_iter()
313                .map(|value| value.into_owned())
314                .collect::<Vec<_>>()
315                .join("; "),
316        )
317    }
318}