Skip to main content

xbp_cli/provider_support/
http.rs

1use reqwest::header::{
2    HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE, USER_AGENT,
3};
4use reqwest::{Client, Method, StatusCode, Url};
5use serde::Serialize;
6use serde_json::Value;
7use std::borrow::Cow;
8use thiserror::Error;
9
10#[derive(Debug, Clone)]
11pub enum AuthStrategy {
12    None,
13    Bearer(String),
14    Header {
15        name: HeaderName,
16        value: HeaderValue,
17    },
18}
19
20#[derive(Debug, Clone)]
21pub struct RequestFactory {
22    client: Client,
23    base_url: Url,
24    auth: AuthStrategy,
25    default_headers: HeaderMap,
26}
27
/// Raw body of a successful response together with its content type.
///
/// Compares by value (`PartialEq`/`Eq`) so responses can be asserted on
/// directly in tests and deduplicated by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResponseBytes {
    /// Value of the `Content-Type` response header, when present and valid
    /// as a string.
    pub content_type: Option<String>,
    /// Unmodified response body.
    pub body: Vec<u8>,
}
33
34#[derive(Debug, Error)]
35pub enum HttpError {
36    #[error("{message}")]
37    Request {
38        message: String,
39        status: Option<StatusCode>,
40        body: Option<String>,
41    },
42    #[error("failed to build request: {0}")]
43    Build(String),
44    #[error("failed to parse response JSON: {0}")]
45    Decode(String),
46}
47
48impl HttpError {
49    pub fn request(
50        message: impl Into<String>,
51        status: Option<StatusCode>,
52        body: Option<String>,
53    ) -> Self {
54        Self::Request {
55            message: message.into(),
56            status,
57            body,
58        }
59    }
60}
61
62impl RequestFactory {
63    pub fn new(base_url: impl AsRef<str>) -> Result<Self, HttpError> {
64        let client = Client::builder()
65            .user_agent("xbp")
66            .build()
67            .map_err(|error| HttpError::Build(error.to_string()))?;
68        let base_url =
69            Url::parse(base_url.as_ref()).map_err(|error| HttpError::Build(error.to_string()))?;
70        let mut default_headers = HeaderMap::new();
71        default_headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
72        default_headers.insert(USER_AGENT, HeaderValue::from_static("xbp"));
73        Ok(Self {
74            client,
75            base_url,
76            auth: AuthStrategy::None,
77            default_headers,
78        })
79    }
80
81    pub fn with_auth(mut self, auth: AuthStrategy) -> Self {
82        self.auth = auth;
83        self
84    }
85
86    pub fn with_default_header(mut self, name: HeaderName, value: HeaderValue) -> Self {
87        self.default_headers.insert(name, value);
88        self
89    }
90
91    pub async fn get_json<T, Q>(&self, path: &str, query: Option<&Q>) -> Result<T, HttpError>
92    where
93        T: serde::de::DeserializeOwned,
94        Q: Serialize + ?Sized,
95    {
96        self.send_json(Method::GET, path, query, Option::<&Value>::None)
97            .await
98    }
99
100    pub async fn delete_json<T, Q>(&self, path: &str, query: Option<&Q>) -> Result<T, HttpError>
101    where
102        T: serde::de::DeserializeOwned,
103        Q: Serialize + ?Sized,
104    {
105        self.send_json(Method::DELETE, path, query, Option::<&Value>::None)
106            .await
107    }
108
109    pub async fn delete_json_with_body<T, Q, B>(
110        &self,
111        path: &str,
112        query: Option<&Q>,
113        body: &B,
114    ) -> Result<T, HttpError>
115    where
116        T: serde::de::DeserializeOwned,
117        Q: Serialize + ?Sized,
118        B: Serialize + ?Sized,
119    {
120        self.send_json(Method::DELETE, path, query, Some(body))
121            .await
122    }
123
124    pub async fn post_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
125    where
126        T: serde::de::DeserializeOwned,
127        B: Serialize + ?Sized,
128    {
129        self.send_json(Method::POST, path, Option::<&Value>::None, Some(body))
130            .await
131    }
132
133    pub async fn put_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
134    where
135        T: serde::de::DeserializeOwned,
136        B: Serialize + ?Sized,
137    {
138        self.send_json(Method::PUT, path, Option::<&Value>::None, Some(body))
139            .await
140    }
141
142    pub async fn patch_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
143    where
144        T: serde::de::DeserializeOwned,
145        B: Serialize + ?Sized,
146    {
147        self.send_json(Method::PATCH, path, Option::<&Value>::None, Some(body))
148            .await
149    }
150
151    pub async fn post_bytes(
152        &self,
153        path: &str,
154        bytes: Vec<u8>,
155        content_type: &'static str,
156    ) -> Result<ResponseBytes, HttpError> {
157        let response = self
158            .request(Method::POST, path)?
159            .header(CONTENT_TYPE, content_type)
160            .body(bytes)
161            .send()
162            .await
163            .map_err(|error| HttpError::request(error.to_string(), None, None))?;
164        self.read_bytes_response(response).await
165    }
166
167    pub async fn get_bytes<Q>(
168        &self,
169        path: &str,
170        query: Option<&Q>,
171    ) -> Result<ResponseBytes, HttpError>
172    where
173        Q: Serialize + ?Sized,
174    {
175        let mut request = self.request(Method::GET, path)?;
176        if let Some(query) = query {
177            request = request.query(query);
178        }
179        let response = request
180            .send()
181            .await
182            .map_err(|error| HttpError::request(error.to_string(), None, None))?;
183        self.read_bytes_response(response).await
184    }
185
186    async fn send_json<T, Q, B>(
187        &self,
188        method: Method,
189        path: &str,
190        query: Option<&Q>,
191        body: Option<&B>,
192    ) -> Result<T, HttpError>
193    where
194        T: serde::de::DeserializeOwned,
195        Q: Serialize + ?Sized,
196        B: Serialize + ?Sized,
197    {
198        let mut request = self.request(method, path)?;
199        if let Some(query) = query {
200            request = request.query(query);
201        }
202        if let Some(body) = body {
203            request = request.json(body);
204        }
205        let response = request
206            .send()
207            .await
208            .map_err(|error| HttpError::request(error.to_string(), None, None))?;
209        let status = response.status();
210        let body = response
211            .text()
212            .await
213            .map_err(|error| HttpError::request(error.to_string(), Some(status), None))?;
214        if !status.is_success() {
215            let message = extract_cloudflare_error_message(&body)
216                .or_else(|| extract_github_error_message(&body))
217                .unwrap_or_else(|| format!("HTTP {}", status));
218            return Err(HttpError::request(message, Some(status), Some(body)));
219        }
220        serde_json::from_str(&body).map_err(|error| HttpError::Decode(error.to_string()))
221    }
222
223    fn request(&self, method: Method, path: &str) -> Result<reqwest::RequestBuilder, HttpError> {
224        let mut url = self
225            .base_url
226            .join(path)
227            .map_err(|error| HttpError::Build(error.to_string()))?;
228        if path.starts_with('/') {
229            let joined = format!("{}{}", self.base_url.as_str().trim_end_matches('/'), path);
230            url = Url::parse(&joined).map_err(|error| HttpError::Build(error.to_string()))?;
231        }
232
233        let mut builder = self.client.request(method, url);
234        builder = builder.headers(self.default_headers.clone());
235        match &self.auth {
236            AuthStrategy::None => {}
237            AuthStrategy::Bearer(token) => {
238                builder = builder.header(AUTHORIZATION, format!("Bearer {}", token));
239            }
240            AuthStrategy::Header { name, value } => {
241                builder = builder.header(name, value);
242            }
243        }
244        Ok(builder)
245    }
246
247    async fn read_bytes_response(
248        &self,
249        response: reqwest::Response,
250    ) -> Result<ResponseBytes, HttpError> {
251        let status = response.status();
252        let content_type = response
253            .headers()
254            .get(CONTENT_TYPE)
255            .and_then(|value| value.to_str().ok())
256            .map(str::to_string);
257        let bytes = response
258            .bytes()
259            .await
260            .map_err(|error| HttpError::request(error.to_string(), Some(status), None))?;
261        if !status.is_success() {
262            let body = String::from_utf8_lossy(&bytes).to_string();
263            let message = extract_cloudflare_error_message(&body)
264                .or_else(|| extract_github_error_message(&body))
265                .unwrap_or_else(|| format!("HTTP {}", status));
266            return Err(HttpError::request(message, Some(status), Some(body)));
267        }
268        Ok(ResponseBytes {
269            content_type,
270            body: bytes.to_vec(),
271        })
272    }
273}
274
275pub fn extract_github_error_message(body: &str) -> Option<String> {
276    let parsed = serde_json::from_str::<Value>(body.trim()).ok()?;
277    parsed
278        .get("message")
279        .and_then(Value::as_str)
280        .map(str::trim)
281        .filter(|value| !value.is_empty())
282        .map(ToOwned::to_owned)
283}
284
285pub fn extract_cloudflare_error_message(body: &str) -> Option<String> {
286    let parsed = serde_json::from_str::<Value>(body.trim()).ok()?;
287    let errors = parsed.get("errors")?.as_array()?;
288    let messages = errors
289        .iter()
290        .filter_map(|entry| {
291            let code = entry.get("code").and_then(Value::as_i64);
292            let message = entry.get("message").and_then(Value::as_str)?.trim();
293            if message.is_empty() {
294                return None;
295            }
296            Some(match code {
297                Some(code) => Cow::Owned(format!("{} ({})", message, code)),
298                None => Cow::Borrowed(message),
299            })
300        })
301        .collect::<Vec<_>>();
302    if messages.is_empty() {
303        None
304    } else {
305        Some(
306            messages
307                .into_iter()
308                .map(|value| value.into_owned())
309                .collect::<Vec<_>>()
310                .join("; "),
311        )
312    }
313}