Skip to main content

zjhttpc/
requestx.rs

1use hashbrown::HashMap;
2use indexmap::IndexSet;
3use serde::Serialize;
4use url::Url;
5
6use anyhow_ext::{Context, Result};
7use async_std::fs::File;
8use futures::io::BufReader;
9use std::time::Duration;
10
11use crate::{
12    body::{Body, BodyForm, BodyMultipartForm},
13    cookie::Cookie,
14    error::ZjhttpcError,
15    misc::TrustStorePem,
16    proxy::HttpsProxyOption,
17};
18
19pub struct Request {
20    pub method: &'static str,
21    pub url: Url,
22    pub headers: HashMap<String, IndexSet<String>>,
23    pub expect_continue: bool,
24    pub content_type: &'static str,
25    pub basic_auth: Option<(String, String)>,
26    pub content_length: u64,
27    pub send_header_timeout: Option<Duration>,
28    pub read_header_timeout: Option<Duration>,
29    pub read_body_timeout: Option<Duration>,
30    pub connect_timeout: Option<Duration>,
31    pub body: Body,
32    pub trust_store_pem: Option<TrustStorePem>,
33    pub proxy: Option<HttpsProxyOption>,
34}
35
36const LIB_VERSION: &str = env!("CARGO_PKG_VERSION");
37
38impl Request {
39    #[must_use]
40    pub fn new(method: &'static str, url: impl AsRef<str>) -> Result<Self> {
41        let url: Url = url.as_ref().parse()?;
42        let host = url.host_str().ok_or_else(|| ZjhttpcError::NoHost).dot()?;
43        let mut headers = HashMap::new();
44        headers.insert("host".to_owned(), IndexSet::from([host.to_owned()]));
45        headers.insert("user-agent".to_owned(), IndexSet::from([format!("zjhttpc/{LIB_VERSION} (powered by Jinhui)")]));
46        Ok(Request {
47            method,
48            url,
49            headers,
50            expect_continue: false,
51            content_type: "application/octet-stream",
52            basic_auth: None,
53            body: Body::None,
54            content_length: 0,
55            send_header_timeout: None,
56            read_header_timeout: None,
57            read_body_timeout: None,
58            connect_timeout: None,
59            trust_store_pem: None,
60            proxy: None,
61        })
62    }
63
64    pub fn method(mut self, method: &'static str) -> Self {
65        self.method = method;
66        self
67    }
68
69    pub fn add_header(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
70        if let Some(v) = self.headers.get_mut(key.as_ref()) {
71            v.insert(value.as_ref().to_owned());
72        } else {
73            self.headers
74                .insert(key.as_ref().to_owned(), IndexSet::from([value.as_ref().to_owned()]));
75        }
76        self
77    }
78
79    pub fn set_header(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
80        self.headers
81            .insert(key.as_ref().to_owned(), IndexSet::from([value.as_ref().to_owned()]));
82        self
83    }
84
85    pub fn set_headers(mut self, headers: HashMap<String, IndexSet<String>>) -> Self {
86        self.headers.extend(headers);
87        self
88    }
89
90    pub fn set_headers_nondup(
91        mut self,
92        headers: std::collections::HashMap<String, String>,
93    ) -> Self {
94        self.headers.extend(
95            headers
96                .into_iter()
97                .map(|(k, v)| (k, IndexSet::from([v]))),
98        );
99        self
100    }
101
102    /// Set cookies for the request
103    ///
104    /// # Arguments
105    /// * `cookies` - Slice of cookies to set
106    ///
107    /// # Examples
108    /// ```
109    /// use zjhttpc::requestx::Request;
110    /// use zjhttpc::cookie::Cookie;
111    ///
112    /// # fn main() -> anyhow::Result<()> {
113    /// let cookies = vec![
114    ///     Cookie::new("sessionid", "abc123"),
115    ///     Cookie::new("userdata", "eyJ1c2VyIjoiYWxpY2UifQ=="),
116    /// ];
117    ///
118    /// let request = Request::new("GET", "https://example.com/dashboard")?
119    ///     .set_cookie(&cookies);
120    /// # Ok(())
121    /// # }
122    /// ```
123    pub fn set_cookie(mut self, cookies: &[Cookie]) -> Self {
124        let cookie_header = Cookie::format_for_request_cookie_header(cookies);
125        self.headers
126            .insert(crate::header::COOKIE.to_owned(), IndexSet::from([cookie_header]));
127        self
128    }
129
130    pub fn set_queries_serde(mut self, queries: &impl Serialize) -> Result<Self> {
131        let s = serde_qs::to_string(queries).dot()?;
132        self.url.set_query(Some(s.as_str()));
133        Ok(self)
134    }
135
136    pub fn add_query(mut self, key: &str, value: &str) -> Self {
137        self.url.query_pairs_mut().append_pair(key, value);
138        self
139    }
140
141    pub fn header_one(&self, key: impl AsRef<str>) -> Option<&String> {
142        self.headers.get(key.as_ref()).and_then(|set| set.first())
143    }
144
145    pub fn header_all(&self, key: impl AsRef<str>) -> Option<&IndexSet<String>> {
146        self.headers.get(key.as_ref())
147    }
148
149    pub fn put_expect_continue(mut self) -> Self {
150        self.expect_continue = true;
151        self
152    }
153
154    pub fn set_content_type(mut self, content_type: &'static str) -> Self {
155        self.content_type = content_type;
156        self
157    }
158
159    pub fn set_content_length(mut self, len: u64) -> Self {
160        self.content_length = len;
161        self
162    }
163
164    pub fn set_basic_auth(mut self, username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
165        self.basic_auth = Some((username.as_ref().to_owned(), password.as_ref().to_owned()));
166        self
167    }
168
169    pub fn set_body_string(mut self, body: impl AsRef<str>) -> Self {
170        self.content_length = body.as_ref().len() as u64;
171        self.body = Body::Str(body.as_ref().to_owned());
172        self
173    }
174
175    pub fn set_body_stream<R>(mut self, body: R, length: u64) -> Self
176    where
177        R: async_std::io::Read + Unpin + Send + Sync + 'static,
178    {
179        self.content_length = length;
180        self.body = Body::Stream(Box::new(body));
181        self
182    }
183
184    pub async fn set_body_file(mut self, file_path: impl AsRef<std::path::Path>) -> Result<Self> {
185        let p = async_std::path::PathBuf::from(file_path.as_ref());
186        let len = p.metadata().await.dot()?.len();
187        self.content_length = len;
188        let file = File::open(p).await.dot()?;
189        let buf_reader = BufReader::new(file);
190        self.body = Body::Stream(Box::new(buf_reader));
191        Ok(self)
192    }
193
194    pub fn set_body_slice(mut self, body: impl AsRef<[u8]>) -> Self {
195        let bytes = body.as_ref();
196        self.content_length = bytes.len() as u64;
197        self.body = Body::Bytes(bytes.to_vec());
198        self
199    }
200
201    /// Set the request body as application/x-www-form-urlencoded form data.
202    ///
203    /// This method automatically sets the Content-Type header to
204    /// "application/x-www-form-urlencoded", overwriting any previous value.
205    ///
206    /// # Arguments
207    /// * `form` - A BodyForm instance containing the form fields
208    ///
209    /// # Examples
210    /// ```
211    /// use zjhttpc::body::BodyForm;
212    /// use zjhttpc::requestx::Request;
213    ///
214    /// # fn main() -> anyhow::Result<()> {
215    /// let form = BodyForm::new()
216    ///     .add("username", "alice")
217    ///     .add("password", "secret")
218    ///     .add("tags", "rust")
219    ///     .add("tags", "http");
220    ///
221    /// let request = Request::new("POST", "https://example.com/login")?
222    ///     .set_body_form(form);
223    /// # Ok(())
224    /// # }
225    /// ```
226    #[must_use]
227    pub fn set_body_form(mut self, form: BodyForm) -> Self {
228        // Auto-set Content-Type to application/x-www-form-urlencoded
229        self.content_type = "application/x-www-form-urlencoded";
230
231        // Serialize the form data
232        let serialized = form.serialize();
233
234        // Set the body
235        self.content_length = serialized.len() as u64;
236        self.body = Body::Str(serialized);
237
238        self
239    }
240
241    /// Set the request body as multipart/form-data.
242    ///
243    /// This method automatically sets the Content-Type header to
244    /// "multipart/form-data; boundary=XXXX", overwriting any previous value.
245    ///
246    /// The actual serialization happens when sending the request to avoid
247    /// loading entire files into memory.
248    ///
249    /// # Arguments
250    /// * `form` - A BodyMultipartForm instance containing the form fields
251    ///
252    /// # Examples
253    /// ```
254    /// use zjhttpc::body::BodyMultipartForm;
255    /// use zjhttpc::requestx::Request;
256    /// use std::path::PathBuf;
257    ///
258    /// # fn main() -> anyhow::Result<()> {
259    /// let form = BodyMultipartForm::new()
260    ///     .add("username", "alice")
261    ///     .add("bio", "Hello, world!")
262    ///     .add_file_path("avatar", PathBuf::from("/path/to/avatar.jpg"))?;
263    ///
264    /// let request = Request::new("POST", "https://example.com/upload")?
265    ///     .set_body_multipart_form(form);
266    /// # Ok(())
267    /// # }
268    /// ```
269    #[must_use]
270    pub fn set_body_multipart_form(mut self, form: BodyMultipartForm) -> Self {
271        // Auto-set Content-Type to multipart/form-data with boundary
272        let boundary = form.boundary().to_string();
273        self.content_type = Box::leak(
274            format!("multipart/form-data; boundary={}", boundary)
275                .into_boxed_str()
276        );
277
278        // For multipart forms, we can't know the content-length upfront
279        // without reading all files, so set to 0 (will use chunked encoding)
280        self.content_length = 0;
281        self.body = Body::MultipartForm(form);
282
283        self
284    }
285
286    pub fn set_send_header_timeout(mut self, dur: Duration) -> Self {
287        self.send_header_timeout = Some(dur);
288        self
289    }
290
291    pub fn set_read_header_timeout(mut self, dur: Duration) -> Self {
292        self.read_header_timeout = Some(dur);
293        self
294    }
295
296    pub fn set_read_body_timeout(mut self, dur: Duration) -> Self {
297        self.read_body_timeout = Some(dur);
298        self
299    }
300
301    /// Deprecated: Use set_read_header_timeout instead
302    pub fn set_header_timeout(mut self, dur: Duration) -> Self {
303        self.read_header_timeout = Some(dur);
304        self
305    }
306
307    pub fn set_proxy(mut self, proxy: HttpsProxyOption) -> Self {
308        self.proxy = Some(proxy);
309        self
310    }
311
312    pub fn set_proxy_from_url(mut self, proxy_url: impl AsRef<str>) -> Result<Self> {
313        let proxy = HttpsProxyOption::new(proxy_url)?;
314        self.proxy = Some(proxy);
315        Ok(self)
316    }
317
318    pub fn set_connect_timeout(mut self, dur: Duration) -> Self {
319        self.connect_timeout = Some(dur);
320        self
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;

    #[test]
    fn test_url_parsing() {
        // Test basic URL parsing
        let url = Url::parse("http://example.com/path").unwrap();
        assert_eq!(url.scheme(), "http");
        assert_eq!(url.host_str().unwrap(), "example.com");
        assert_eq!(url.path(), "/path");
        assert_eq!(url.fragment(), None);

        // `Url::port` reports `None` when the port equals the scheme's default,
        // even if it was written out explicitly; the default is still known.
        let url = Url::parse("https://example.com:443/secure").unwrap();
        assert_eq!(url.scheme(), "https");
        assert_eq!(url.port(), None);
        assert_eq!(url.port_or_known_default(), Some(443));

        // A non-default port is reported as-is.
        let url = Url::parse("https://example.com:1443/secure").unwrap();
        assert_eq!(url.scheme(), "https");
        assert_eq!(url.port(), Some(1443));

        // Test URL with query parameters
        let url = Url::parse("http://example.com/search?q=test&page=1").unwrap();
        assert_eq!(url.query(), Some("q=test&page=1"));

        // Test URL with basic auth
        let url = Url::parse("http://user:pass@example.com").unwrap();
        assert_eq!(url.username(), "user");
        assert_eq!(url.password(), Some("pass"));

        // Test invalid URL
        assert!(Url::parse("not a url").is_err());
    }

    #[test]
    fn test_url_set_query() {
        let mut url = Url::parse("http://user:pass@example.com").unwrap();
        url.query_pairs_mut().append_pair("a", "b");
        url.query_pairs_mut().append_pair("c", "d");
        // Successive `query_pairs_mut` calls append rather than overwrite.
        assert_eq!(url.query(), Some("a=b&c=d"));
        assert_eq!(url.as_str(), "http://user:pass@example.com/?a=b&c=d");
    }

    #[test]
    fn test_request_proxy_configuration() {
        let request = Request::new("GET", "http://example.com").unwrap();
        assert!(request.proxy.is_none());

        let proxy = crate::proxy::HttpsProxyOption::new("http://proxy.example.com:8080").unwrap();
        let request = request.set_proxy(proxy);
        assert!(request.proxy.is_some());
        assert_eq!(request.proxy.unwrap().url.host_str().unwrap(), "proxy.example.com");
    }

    #[test]
    fn test_request_proxy_from_url() {
        let result = Request::new("GET", "http://example.com").unwrap()
            .set_proxy_from_url("http://proxy.example.com:8080");
        assert!(result.is_ok());
        let request = result.unwrap();
        assert!(request.proxy.is_some());
        assert_eq!(request.proxy.unwrap().url.host_str().unwrap(), "proxy.example.com");
    }

    #[test]
    fn test_request_invalid_proxy_url() {
        let result = Request::new("GET", "http://example.com").unwrap()
            .set_proxy_from_url("invalid-url");
        assert!(result.is_err());
    }

    #[test]
    fn test_request_connect_timeout() {
        let request = Request::new("GET", "http://example.com").unwrap()
            .set_connect_timeout(Duration::from_secs(5));
        assert_eq!(request.connect_timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_request_connect_timeout_default() {
        let request = Request::new("GET", "http://example.com").unwrap();
        assert_eq!(request.connect_timeout, None);
    }

    #[test]
    fn test_add_query_to_url_without_existing_query() {
        let request = Request::new("GET", "http://example.com")
            .unwrap()
            .add_query("param1", "value1")
            .add_query("param2", "value2");

        assert_eq!(request.url.query(), Some("param1=value1&param2=value2"));
    }

    #[test]
    fn test_add_query_to_url_with_existing_query() {
        let request = Request::new("GET", "http://example.com?existing=test")
            .unwrap()
            .add_query("param1", "value1");

        assert_eq!(request.url.query(), Some("existing=test&param1=value1"));
    }

    #[test]
    fn test_add_query_with_special_characters() {
        let request = Request::new("GET", "http://example.com")
            .unwrap()
            .add_query("query", "hello world")
            .add_query("symbol", "@#$%");

        let query = request.url.query().unwrap();
        assert!(query.contains("query=hello+world"));
        assert!(query.contains("symbol=%40%23%24%25"));
    }

    #[test]
    fn test_add_query_with_empty_values() {
        let request = Request::new("GET", "http://example.com")
            .unwrap()
            .add_query("empty", "")
            .add_query("param", "value");

        assert_eq!(request.url.query(), Some("empty=&param=value"));
    }

    #[test]
    fn test_add_query_with_duplicate_keys() {
        let request = Request::new("GET", "http://example.com")
            .unwrap()
            .add_query("key", "value1")
            .add_query("key", "value2");

        let query = request.url.query().unwrap();
        assert!(query.contains("key=value1"));
        assert!(query.contains("key=value2"));
        assert_eq!(query, "key=value1&key=value2");
    }

    #[test]
    fn test_add_query_to_https_url() {
        let request = Request::new("GET", "https://api.example.com/endpoint")
            .unwrap()
            .add_query("api_key", "secret123")
            .add_query("format", "json");

        assert_eq!(request.url.query(), Some("api_key=secret123&format=json"));
        assert_eq!(request.url.scheme(), "https");
        assert_eq!(request.url.path(), "/endpoint");
    }

    #[test]
    fn test_add_query_with_path_and_fragment() {
        let request = Request::new("GET", "http://example.com/path/to/resource#section")
            .unwrap()
            .add_query("filter", "all");

        assert_eq!(request.url.query(), Some("filter=all"));
        assert_eq!(request.url.path(), "/path/to/resource");
        assert_eq!(request.url.fragment(), Some("section"));
    }

    #[test]
    fn test_add_query_unicode_characters() {
        let request = Request::new("GET", "http://example.com")
            .unwrap()
            .add_query("emoji", "🚀")
            .add_query("chinese", "你好");

        let query = request.url.query().unwrap();
        assert!(query.contains("emoji=%F0%9F%9A%80"));
        assert!(query.contains("chinese=%E4%BD%A0%E5%A5%BD"));
    }

    #[test]
    fn test_add_query_chainable_api() {
        let request = Request::new("GET", "http://example.com")
            .unwrap()
            .add_query("a", "1")
            .add_query("b", "2")
            .add_header("Accept", "application/json")
            .add_query("c", "3");

        assert_eq!(request.url.query(), Some("a=1&b=2&c=3"));
        assert!(request.headers.contains_key("Accept"));
        assert_eq!(request.headers.get("Accept").unwrap().first().unwrap(), "application/json");
    }

    #[test]
    fn test_content_type_constants() {
        use crate::content_type;

        let request = Request::new("POST", "http://example.com")
            .unwrap()
            .set_content_type(content_type::APPLICATION_JSON);

        assert_eq!(request.content_type, "application/json");

        let request = Request::new("POST", "http://example.com")
            .unwrap()
            .set_content_type(content_type::TEXT_HTML);

        assert_eq!(request.content_type, "text/html");

        let request = Request::new("POST", "http://example.com")
            .unwrap()
            .set_content_type(content_type::IMAGE_PNG);

        assert_eq!(request.content_type, "image/png");
    }
}