salvo_core/test/request/
builder.rs

1use std::borrow::Borrow;
2use std::str;
3use std::sync::Arc;
4
5use base64::engine::{general_purpose, Engine};
6use http::header::{self, HeaderMap, HeaderValue, IntoHeaderName};
7use http::uri::Scheme;
8use url::Url;
9
10use crate::http::body::ReqBody;
11use crate::http::Method;
12use crate::routing::{FlowCtrl, Router};
13use crate::{Depot, Error, Handler, Request, Response, Service};
14
/// The main way of building a [`Request`] for tests.
///
/// A `RequestBuilder` can be created directly with [`RequestBuilder::new`], but the recommended
/// way is to use one of the simpler constructors available in the
/// [`TestClient`](crate::test::TestClient) struct, such as `get`, `post`, etc.
#[derive(Debug)]
pub struct RequestBuilder {
    // Target URL; query parameters are appended to it by `query`/`queries`.
    url: Url,
    // HTTP method used when the request is built.
    method: Method,
    // Headers copied onto the built request.
    headers: HeaderMap,
    // params: HashMap<String, String>,
    // Request body; starts as `ReqBody::None` in `new`.
    body: ReqBody,
}
28
29impl RequestBuilder {
30    /// Create a new `RequestBuilder` with the base URL and the given method.
31    ///
32    /// # Panics
33    /// Panics if the base url is invalid or if the method is CONNECT.
34    #[must_use]
35    pub fn new<U>(url: U, method: Method) -> Self
36    where
37        U: AsRef<str>,
38    {
39        let url = Url::parse(url.as_ref()).expect("invalid url");
40        Self {
41            url,
42            method,
43            headers: HeaderMap::new(),
44            // params: HeaderMap::new(),
45            body: ReqBody::None,
46        }
47    }
48}
49
50impl RequestBuilder {
51    /// Associate a query string parameter to the given value.
52    ///
53    /// The same key can be used multiple times.
54    #[must_use]
55    pub fn query<K, V>(mut self, key: K, value: V) -> Self
56    where
57        K: AsRef<str>,
58        V: ToString,
59    {
60        self.url.query_pairs_mut().append_pair(key.as_ref(), &value.to_string());
61        self
62    }
63
64    /// Associated a list of pairs to query parameters.
65    ///
66    /// The same key can be used multiple times.
67    ///
68    /// # Example
69    /// ```ignore
70    /// TestClient::get("http://foo.bar").queries(&[("p1", "v1"), ("p2", "v2")]);
71    /// ```
72    #[must_use]
73    pub fn queries<P, K, V>(mut self, pairs: P) -> Self
74    where
75        P: IntoIterator,
76        P::Item: Borrow<(K, V)>,
77        K: AsRef<str>,
78        V: ToString,
79    {
80        for pair in pairs.into_iter() {
81            let (key, value) = pair.borrow();
82            self.url.query_pairs_mut().append_pair(key.as_ref(), &value.to_string());
83        }
84        self
85    }
86
87    // /// Associate a url param to the given value.
88    // pub fn param<K, V>(mut self, key: K, value: V) -> Self
89    // where
90    //     K: AsRef<str>,
91    //     V: ToString,
92    // {
93    //     self.params.insert(key.as_ref(), &value.to_string());
94    //     self
95    // }
96
97    // /// Associated a list of url params.
98    // pub fn params<P, K, V>(mut self, pairs: P) -> Self
99    // where
100    //     P: IntoIterator,
101    //     P::Item: Borrow<(K, V)>,
102    //     K: AsRef<str>,
103    //     V: ToString,
104    // {
105    //     for pair in pairs.into_iter() {
106    //         let (key, value) = pair.borrow();
107    //         self.params.insert(key.as_ref(), &value.to_string());
108    //     }
109    //     self
110    // }
111
112    /// Enable HTTP basic authentication.
113    #[must_use]
114    pub fn basic_auth(self, username: impl std::fmt::Display, password: Option<impl std::fmt::Display>) -> Self {
115        let auth = match password {
116            Some(password) => format!("{username}:{password}"),
117            None => format!("{username}:"),
118        };
119        let encoded = format!("Basic {}", general_purpose::STANDARD.encode(auth.as_bytes()));
120        self.add_header(header::AUTHORIZATION, encoded, true)
121    }
122
123    /// Enable HTTP bearer authentication.
124    #[must_use]
125    pub fn bearer_auth(self, token: impl Into<String>) -> Self {
126        self.add_header(header::AUTHORIZATION, format!("Bearer {}", token.into()), true)
127    }
128
129    /// Sets the body of this request.
130    #[must_use]
131    pub fn body(mut self, body: impl Into<ReqBody>) -> Self {
132        self.body = body.into();
133        self
134    }
135
136    /// Sets the body of this request to be text.
137    ///
138    /// If the `Content-Type` header is unset, it will be set to `text/plain` and the charset to UTF-8.
139    #[must_use]
140    pub fn text(mut self, body: impl Into<String>) -> Self {
141        self.headers
142            .entry(header::CONTENT_TYPE)
143            .or_insert(HeaderValue::from_static("text/plain; charset=utf-8"));
144        self.body(body.into())
145    }
146
147    /// Sets the body of this request to be bytes.
148    ///
149    /// If the `Content-Type` header is unset, it will be set to `application/octet-stream`.
150    #[must_use]
151    pub fn bytes(mut self, body: Vec<u8>) -> Self {
152        self.headers
153            .entry(header::CONTENT_TYPE)
154            .or_insert(HeaderValue::from_static("application/octet-stream"));
155        self.body(body)
156    }
157
158    /// Sets the body of this request to be the JSON representation of the given object.
159    ///
160    /// If the `Content-Type` header is unset, it will be set to `application/json` and the charset to UTF-8.
161    #[must_use]
162    pub fn json<T: serde::Serialize>(mut self, value: &T) -> Self {
163        self.headers
164            .entry(header::CONTENT_TYPE)
165            .or_insert(HeaderValue::from_static("application/json; charset=utf-8"));
166        self.body(serde_json::to_vec(value).expect("Failed to serialize json."))
167    }
168
169    /// Sets the body of this request to be the JSON representation of the given string.
170    ///
171    /// If the `Content-Type` header is unset, it will be set to `application/json` and the charset to UTF-8.
172    #[must_use]
173    pub fn raw_json(mut self, value: impl Into<String>) -> Self {
174        self.headers
175            .entry(header::CONTENT_TYPE)
176            .or_insert(HeaderValue::from_static("application/json; charset=utf-8"));
177        self.body(value.into())
178    }
179
180    /// Sets the body of this request to be the URL-encoded representation of the given object.
181    ///
182    /// If the `Content-Type` header is unset, it will be set to `application/x-www-form-urlencoded`.
183    #[must_use]
184    pub fn form<T: serde::Serialize>(mut self, value: &T) -> Self {
185        let body = serde_urlencoded::to_string(value)
186            .expect("`serde_urlencoded::to_string` returns error")
187            .into_bytes();
188        self.headers
189            .entry(header::CONTENT_TYPE)
190            .or_insert(HeaderValue::from_static("application/x-www-form-urlencoded"));
191        self.body(body)
192    }
193    /// Sets the body of this request to be the URL-encoded representation of the given string.
194    ///
195    /// If the `Content-Type` header is unset, it will be set to `application/x-www-form-urlencoded`.
196    #[must_use]
197    pub fn raw_form(mut self, value: impl Into<String>) -> Self {
198        self.headers
199            .entry(header::CONTENT_TYPE)
200            .or_insert(HeaderValue::from_static("application/x-www-form-urlencoded"));
201        self.body(value.into())
202    }
203    /// Modify a header for this response.
204    ///
205    /// When `overwrite` is set to `true`, If the header is already present, the value will be replaced.
206    /// When `overwrite` is set to `false`, The new header is always appended to the request, even if the header already exists.
207    #[must_use]
208    pub fn add_header<N, V>(mut self, name: N, value: V, overwrite: bool) -> Self
209    where
210        N: IntoHeaderName,
211        V: TryInto<HeaderValue>,
212    {
213        let value = value
214            .try_into()
215            .map_err(|_| Error::Other("invalid header value".into()))
216            .expect("invalid header value");
217        if overwrite {
218            self.headers.insert(name, value);
219        } else {
220            self.headers.append(name, value);
221        }
222        self
223    }
224
225    /// Build final request.
226    pub fn build(self) -> Request {
227        let req = self.build_hyper();
228        let scheme = req.uri().scheme().cloned().unwrap_or(Scheme::HTTP);
229        Request::from_hyper(req, scheme)
230    }
231
232    /// Build hyper request.
233    pub fn build_hyper(self) -> hyper::Request<ReqBody> {
234        let Self {
235            url,
236            method,
237            headers,
238            body,
239        } = self;
240        let mut req = hyper::Request::builder().method(method).uri(url.to_string());
241        (*req.headers_mut().expect("`headers_mut` returns `None`")) = headers;
242        req.body(body).expect("invalid request body")
243    }
244
245    /// Send request to target, such as [`Router`], [`Service`], [`Handler`].
246    pub async fn send(self, target: impl SendTarget + Send) -> Response {
247        #[cfg(feature = "cookie")]
248        {
249            let mut response = target.call(self.build()).await;
250            let values = response
251                .cookies
252                .delta()
253                .filter_map(|c| c.encoded().to_string().parse().ok())
254                .collect::<Vec<_>>();
255            for hv in values {
256                response.headers_mut().insert(header::SET_COOKIE, hv);
257            }
258            response
259        }
260        #[cfg(not(feature = "cookie"))]
261        target.call(self.build()).await
262    }
263}
264
/// Trait for sending a request to a target, such as [`Router`], [`Service`] or [`Handler`], for
/// test usage.
///
/// This file implements it for `&Service`, `Router`, `Arc<Router>`, and for `T` / `Arc<T>`
/// where `T: Handler + Send`.
pub trait SendTarget {
    /// Consumes the target, sends `req` to it and resolves to the produced [`Response`].
    ///
    /// The returned future must be `Send`.
    #[must_use = "future must be used"]
    fn call(self, req: Request) -> impl Future<Output = Response> + Send;
}
271impl SendTarget for &Service {
272    async fn call(self, req: Request) -> Response {
273        self.handle(req).await
274    }
275}
276impl SendTarget for Router {
277    async fn call(self, req: Request) -> Response {
278        let router = Arc::new(self);
279        SendTarget::call(router, req).await
280    }
281}
282impl SendTarget for Arc<Router> {
283    async fn call(self, req: Request) -> Response {
284        let srv = Service::new(self);
285        srv.handle(req).await
286    }
287}
288impl<T> SendTarget for Arc<T>
289where
290    T: Handler + Send,
291{
292    async fn call(self, req: Request) -> Response {
293        let mut req = req;
294        let mut depot = Depot::new();
295        #[cfg(not(feature = "cookie"))]
296        let mut res = Response::new();
297        #[cfg(feature = "cookie")]
298        let mut res = Response::with_cookies(req.cookies.clone());
299        let mut ctrl = FlowCtrl::new(vec![self.clone()]);
300        self.handle(&mut req, &mut depot, &mut res, &mut ctrl).await;
301        res
302    }
303}
304impl<T> SendTarget for T
305where
306    T: Handler + Send,
307{
308    async fn call(self, req: Request) -> Response {
309        let handler = Arc::new(self);
310        SendTarget::call(handler, req).await
311    }
312}