Skip to main content

toolcraft_request/
client.rs

1use futures_util::StreamExt;
2use reqwest::{Client, multipart};
3use url::Url;
4
5use crate::{
6    error::{Error, Result},
7    header_map::HeaderMap,
8    response::{ByteStream, Response},
9};
10
11/// An HTTP request builder and executor with base URL and default headers.
12#[derive(Debug)]
13pub struct Request {
14    client: Client,
15    base_url: Option<Url>,
16    default_headers: HeaderMap,
17}
18
19impl Request {
20    /// Create a new Request client.
21    pub fn new() -> Result<Self> {
22        let client = Client::builder()
23            .build()
24            .map_err(|e| Error::ErrorMessage(e.to_string().into()))?;
25        Ok(Request {
26            client,
27            base_url: None,
28            default_headers: HeaderMap::new(),
29        })
30    }
31
32    pub fn with_timeout(timeout_sec: u64) -> Result<Self> {
33        let client = Client::builder()
34            .timeout(std::time::Duration::from_secs(timeout_sec))
35            .build()
36            .map_err(|e| Error::ErrorMessage(e.to_string().into()))?;
37        Ok(Request {
38            client,
39            base_url: None,
40            default_headers: HeaderMap::new(),
41        })
42    }
43
44    /// Set the base URL for all requests.
45    pub fn set_base_url(&mut self, base_url: &str) -> Result<()> {
46        let mut url_str = base_url.to_string();
47        if !url_str.ends_with('/') {
48            url_str.push('/');
49        }
50        let url = Url::parse(&url_str)?;
51        self.base_url = Some(url);
52        Ok(())
53    }
54
55    /// Set default headers to be applied on all requests.
56    pub fn set_default_headers(&mut self, headers: HeaderMap) {
57        self.default_headers = headers;
58    }
59
60    /// Send a GET request.
61    pub async fn get(
62        &self,
63        endpoint: &str,
64        query: Option<Vec<(String, String)>>,
65        headers: Option<HeaderMap>,
66    ) -> Result<Response> {
67        let url = self.build_url(endpoint, query)?;
68        let mut request = self.client.get(url.as_str());
69
70        let mut combined_headers = self.default_headers.clone();
71        if let Some(custom_headers) = headers {
72            combined_headers.merge(custom_headers);
73        }
74        request = request.headers(combined_headers.inner().clone());
75
76        let response = request.send().await?;
77        Ok(response.into())
78    }
79
80    /// Send a POST request with JSON body.
81    pub async fn post(
82        &self,
83        endpoint: &str,
84        body: &serde_json::Value,
85        headers: Option<HeaderMap>,
86    ) -> Result<Response> {
87        let url = self.build_url(endpoint, None)?;
88        let mut request = self.client.post(url).json(body);
89
90        let mut combined_headers = self.default_headers.clone();
91        if let Some(custom_headers) = headers {
92            combined_headers.merge(custom_headers);
93        }
94        request = request.headers(combined_headers.inner().clone());
95
96        let response = request.send().await?;
97        Ok(response.into())
98    }
99
100    /// Send a PUT request with JSON body.
101    pub async fn put(
102        &self,
103        endpoint: &str,
104        body: &serde_json::Value,
105        headers: Option<HeaderMap>,
106    ) -> Result<Response> {
107        let url = self.build_url(endpoint, None)?;
108        let mut request = self.client.put(url).json(body);
109
110        let mut combined_headers = self.default_headers.clone();
111        if let Some(custom_headers) = headers {
112            combined_headers.merge(custom_headers);
113        }
114        request = request.headers(combined_headers.inner().clone());
115
116        let response = request.send().await?;
117        Ok(response.into())
118    }
119
120    /// Send a DELETE request.
121    pub async fn delete(&self, endpoint: &str, headers: Option<HeaderMap>) -> Result<Response> {
122        let url = self.build_url(endpoint, None)?;
123        let mut request = self.client.delete(url);
124
125        let mut combined_headers = self.default_headers.clone();
126        if let Some(custom_headers) = headers {
127            combined_headers.merge(custom_headers);
128        }
129        request = request.headers(combined_headers.inner().clone());
130
131        let response = request.send().await?;
132        Ok(response.into())
133    }
134
135    /// Send a HEAD request.
136    pub async fn head(&self, endpoint: &str, headers: Option<HeaderMap>) -> Result<Response> {
137        let url = self.build_url(endpoint, None)?;
138        let mut request = self.client.head(url);
139
140        let mut combined_headers = self.default_headers.clone();
141        if let Some(custom_headers) = headers {
142            combined_headers.merge(custom_headers);
143        }
144        request = request.headers(combined_headers.inner().clone());
145
146        let response = request.send().await?;
147        Ok(response.into())
148    }
149
150    /// Send a POST request with multipart/form-data.
151    ///
152    /// # Arguments
153    /// * `endpoint` - The URL endpoint
154    /// * `form_fields` - Vector of form fields (text or file)
155    /// * `headers` - Optional custom headers
156    ///
157    /// # Important
158    /// The `Content-Type` header will be automatically removed from default and custom headers
159    /// to allow reqwest to set the correct `multipart/form-data` with boundary.
160    ///
161    /// # Example
162    /// ```no_run
163    /// use toolcraft_request::{FormField, Request};
164    ///
165    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
166    /// let client = Request::new()?;
167    /// let fields = vec![
168    ///     FormField::text("name", "John"),
169    ///     FormField::file("avatar", "/path/to/image.jpg").await?,
170    /// ];
171    /// let response = client.post_form("/upload", fields, None).await?;
172    /// # Ok(())
173    /// # }
174    /// ```
175    pub async fn post_form(
176        &self,
177        endpoint: &str,
178        form_fields: Vec<FormField>,
179        headers: Option<HeaderMap>,
180    ) -> Result<Response> {
181        let url = self.build_url(endpoint, None)?;
182
183        let mut form = multipart::Form::new();
184        for field in form_fields {
185            match field {
186                FormField::Text { name, value } => {
187                    form = form.text(name, value);
188                }
189                FormField::File {
190                    name,
191                    filename,
192                    content,
193                } => {
194                    let part = multipart::Part::bytes(content).file_name(filename);
195                    form = form.part(name, part);
196                }
197            }
198        }
199
200        let mut combined_headers = self.default_headers.clone();
201        if let Some(custom_headers) = headers {
202            combined_headers.merge(custom_headers);
203        }
204
205        // Remove Content-Type to let reqwest set the correct multipart/form-data with boundary
206        combined_headers.remove("Content-Type");
207        combined_headers.remove("content-type");
208
209        let mut request = self.client.post(url).multipart(form);
210        request = request.headers(combined_headers.inner().clone());
211
212        let response = request.send().await?;
213        Ok(response.into())
214    }
215
216    /// Send a streaming POST request and return the response stream.
217    pub async fn post_stream(
218        &self,
219        endpoint: &str,
220        body: &serde_json::Value,
221        headers: Option<HeaderMap>,
222    ) -> Result<ByteStream> {
223        let url = self.build_url(endpoint, None)?;
224        let mut request = self.client.post(url).json(body);
225
226        let mut combined_headers = self.default_headers.clone();
227        if let Some(custom_headers) = headers {
228            combined_headers.merge(custom_headers);
229        }
230        request = request.headers(combined_headers.inner().clone());
231
232        let response = request.send().await?;
233        if !response.status().is_success() {
234            return Err(Error::ErrorMessage(
235                format!("Unexpected status: {}", response.status()).into(),
236            ));
237        }
238
239        let stream = response
240            .bytes_stream()
241            .map(|chunk_result| chunk_result.map_err(Error::from));
242        Ok(Box::pin(stream))
243    }
244
245    /// Build a full URL by combining base URL, endpoint, and optional query parameters.
246    fn build_url(&self, endpoint: &str, query: Option<Vec<(String, String)>>) -> Result<Url> {
247        let mut url = if let Some(base_url) = &self.base_url {
248            base_url.join(endpoint)?
249        } else {
250            Url::parse(endpoint)?
251        };
252
253        if let Some(query_params) = query {
254            let query_pairs: Vec<(String, String)> = query_params.into_iter().collect();
255            url.query_pairs_mut().extend_pairs(query_pairs);
256        }
257
258        Ok(url)
259    }
260}
261
262/// Parse a full URL with optional query parameters.
263pub fn parse_url(url: &str, query: Option<Vec<(String, String)>>) -> Result<Url> {
264    let mut url = Url::parse(url)?;
265    if let Some(query_params) = query {
266        let query_pairs: Vec<(String, String)> = query_params.into_iter().collect();
267        url.query_pairs_mut().extend_pairs(query_pairs);
268    }
269    Ok(url)
270}
271
/// Represents a field in a multipart/form-data request.
///
/// Construct with [`FormField::text`], [`FormField::file_from_bytes`], or [`FormField::file`]
/// and pass a list of them to [`Request::post_form`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormField {
    /// A text field.
    Text {
        /// Form field name.
        name: String,
        /// Text value of the field.
        value: String,
    },
    /// A file field whose contents are held in memory.
    File {
        /// Form field name.
        name: String,
        /// File name attached to the part (passed to `Part::file_name` by `post_form`).
        filename: String,
        /// Raw file bytes.
        content: Vec<u8>,
    },
}
284
285impl FormField {
286    /// Create a text field.
287    ///
288    /// # Example
289    /// ```
290    /// use toolcraft_request::FormField;
291    /// let field = FormField::text("username", "john_doe");
292    /// ```
293    pub fn text(name: impl Into<String>, value: impl Into<String>) -> Self {
294        FormField::Text {
295            name: name.into(),
296            value: value.into(),
297        }
298    }
299
300    /// Create a file field from bytes.
301    ///
302    /// # Example
303    /// ```
304    /// use toolcraft_request::FormField;
305    /// let data = b"file content".to_vec();
306    /// let field = FormField::file_from_bytes("avatar", "photo.jpg", data);
307    /// ```
308    pub fn file_from_bytes(
309        name: impl Into<String>,
310        filename: impl Into<String>,
311        content: Vec<u8>,
312    ) -> Self {
313        FormField::File {
314            name: name.into(),
315            filename: filename.into(),
316            content,
317        }
318    }
319
320    /// Create a file field by reading from a file path.
321    ///
322    /// # Example
323    /// ```no_run
324    /// use toolcraft_request::FormField;
325    ///
326    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
327    /// let field = FormField::file("avatar", "/path/to/image.jpg").await?;
328    /// # Ok(())
329    /// # }
330    /// ```
331    pub async fn file(name: impl Into<String>, path: impl AsRef<std::path::Path>) -> Result<Self> {
332        let path = path.as_ref();
333        let filename = path
334            .file_name()
335            .and_then(|n| n.to_str())
336            .ok_or_else(|| Error::ErrorMessage("Invalid file path".into()))?
337            .to_string();
338
339        let content = tokio::fs::read(path)
340            .await
341            .map_err(|e| Error::ErrorMessage(format!("Failed to read file: {}", e).into()))?;
342
343        Ok(FormField::File {
344            name: name.into(),
345            filename,
346            content,
347        })
348    }
349}