Skip to main content

tork_core/testing/
request.rs

1//! Request builders for the test client.
2
3use std::sync::Arc;
4
5use bytes::Bytes;
6use http::header::CONTENT_TYPE;
7use http::{HeaderName, HeaderValue, Method};
8use serde::Serialize;
9
10use super::client::{Shared, TestHeader};
11use super::response::TestResponse;
12use crate::error::{Error, Result};
13
/// Boundary delimiter for generated `multipart/form-data` bodies.
///
/// The same value is written before every part, in the closing delimiter, and
/// in the `boundary=` parameter of the request's `Content-Type`. Part values
/// are emitted verbatim, so a payload containing this exact string would yield
/// a malformed body; the value is chosen to be unlikely to occur in test data.
const MULTIPART_BOUNDARY: &str = "----TorkTestBoundary7MA4YWxkTrZu0gW";
16
/// A pending request body and its content type.
///
/// The derived `Default` is the "no body" state: no content type and empty
/// bytes. The builders fall back to it when a body cannot be serialized.
#[derive(Default)]
pub(crate) struct PendingBody {
    /// `Content-Type` chosen by the body setter; `None` when none was chosen
    /// (for example a raw byte body).
    pub(crate) content_type: Option<HeaderValue>,
    /// The encoded body bytes; empty when no body has been set.
    pub(crate) bytes: Bytes,
}
23
/// Builds and sends a single HTTP request.
///
/// Created by the verb methods on [`TestClient`](super::TestClient)
/// (`get`/`post`/...). Set headers, query parameters, and a body, then call
/// [`send`](TestRequestBuilder::send).
pub struct TestRequestBuilder {
    /// Shared client state; `send` and `sse` delegate to it.
    shared: Arc<Shared>,
    /// HTTP method of the request.
    method: Method,
    /// Request path, stored as given to the constructor.
    path: String,
    /// Query parameters as `(name, value)` pairs, in the order they were added.
    query: Vec<(String, String)>,
    /// Headers added via `header` / `unsafe_header`, in the order they were added.
    headers: Vec<TestHeader>,
    /// Body set by `json` / `form` / `bytes`; the empty default until then.
    body: PendingBody,
}
37
38impl TestRequestBuilder {
39    pub(crate) fn new(shared: Arc<Shared>, method: Method, path: impl Into<String>) -> Self {
40        Self {
41            shared,
42            method,
43            path: path.into(),
44            query: Vec::new(),
45            headers: Vec::new(),
46            body: PendingBody::default(),
47        }
48    }
49
50    /// Adds a request header. An invalid name or value is ignored.
51    pub fn header(mut self, name: &str, value: &str) -> Self {
52        if let (Ok(name), Ok(value)) = (
53            HeaderName::from_bytes(name.as_bytes()),
54            HeaderValue::from_str(value),
55        ) {
56            self.headers.push(TestHeader::safe(name, value));
57        }
58        self
59    }
60
61    /// Adds a security-sensitive request header, bypassing the in-process guard.
62    pub fn unsafe_header(mut self, name: &str, value: &str) -> Self {
63        if let (Ok(name), Ok(value)) = (
64            HeaderName::from_bytes(name.as_bytes()),
65            HeaderValue::from_str(value),
66        ) {
67            self.headers.push(TestHeader::unsafe_allowed(name, value));
68        }
69        self
70    }
71
72    /// Adds a query parameter.
73    pub fn query(mut self, name: &str, value: &str) -> Self {
74        self.query.push((name.to_owned(), value.to_owned()));
75        self
76    }
77
78    /// Sets a JSON body (`application/json`).
79    pub fn json<T: Serialize>(mut self, value: &T) -> Self {
80        match serde_json::to_vec(value) {
81            Ok(bytes) => {
82                self.body = PendingBody {
83                    content_type: Some(HeaderValue::from_static("application/json")),
84                    bytes: Bytes::from(bytes),
85                };
86            }
87            Err(_) => self.body = PendingBody::default(),
88        }
89        self
90    }
91
92    /// Sets a urlencoded form body (`application/x-www-form-urlencoded`).
93    pub fn form<T: Serialize>(mut self, value: &T) -> Self {
94        match serde_urlencoded::to_string(value) {
95            Ok(text) => {
96                self.body = PendingBody {
97                    content_type: Some(HeaderValue::from_static(
98                        "application/x-www-form-urlencoded",
99                    )),
100                    bytes: Bytes::from(text.into_bytes()),
101                };
102            }
103            Err(_) => self.body = PendingBody::default(),
104        }
105        self
106    }
107
108    /// Sets a raw byte body, with no content type unless one is set via `header`.
109    pub fn bytes(mut self, bytes: impl Into<Bytes>) -> Self {
110        self.body = PendingBody {
111            content_type: None,
112            bytes: bytes.into(),
113        };
114        self
115    }
116
117    /// Switches to a `multipart/form-data` body builder.
118    pub fn multipart(self) -> TestMultipartBuilder {
119        TestMultipartBuilder {
120            shared: self.shared,
121            method: self.method,
122            path: self.path,
123            query: self.query,
124            headers: self.headers,
125            parts: Vec::new(),
126        }
127    }
128
129    /// Sends the request and returns the response.
130    pub async fn send(self) -> Result<TestResponse> {
131        self.shared
132            .send(self.method, self.path, self.query, self.headers, self.body)
133            .await
134    }
135
136    /// Sends the request and reads the response as a Server-Sent Events stream.
137    pub async fn sse(self) -> Result<super::sse::TestSseStream> {
138        self.shared
139            .open_sse(self.method, self.path, self.query, self.headers)
140            .await
141    }
142}
143
144/// One part of a multipart body: a text field or a file.
145struct MultipartPart {
146    name: String,
147    filename: Option<String>,
148    content_type: Option<String>,
149    value: Bytes,
150}
151
/// Builds and sends a `multipart/form-data` request (forms with files).
///
/// Obtained from [`TestRequestBuilder::multipart`], which carries over the
/// method, path, query parameters, and headers collected so far.
pub struct TestMultipartBuilder {
    /// Shared client state; `send` delegates to it.
    shared: Arc<Shared>,
    /// HTTP method of the request.
    method: Method,
    /// Request path.
    path: String,
    /// Query parameters as `(name, value)` pairs, in the order they were added.
    query: Vec<(String, String)>,
    /// Headers added via `header` / `unsafe_header`, in the order they were added.
    headers: Vec<TestHeader>,
    /// Parts to encode, in the order they were added.
    parts: Vec<MultipartPart>,
}
161
162impl TestMultipartBuilder {
163    /// Adds a text field.
164    pub fn text(mut self, name: &str, value: &str) -> Self {
165        self.parts.push(MultipartPart {
166            name: name.to_owned(),
167            filename: None,
168            content_type: None,
169            value: Bytes::from(value.to_owned().into_bytes()),
170        });
171        self
172    }
173
174    /// Adds a file field with the given filename, content type, and bytes.
175    pub fn file_bytes(
176        mut self,
177        name: &str,
178        filename: &str,
179        content_type: &str,
180        bytes: impl Into<Bytes>,
181    ) -> Self {
182        self.parts.push(MultipartPart {
183            name: name.to_owned(),
184            filename: Some(filename.to_owned()),
185            content_type: Some(content_type.to_owned()),
186            value: bytes.into(),
187        });
188        self
189    }
190
191    /// Adds a request header.
192    pub fn header(mut self, name: &str, value: &str) -> Self {
193        if let (Ok(name), Ok(value)) = (
194            HeaderName::from_bytes(name.as_bytes()),
195            HeaderValue::from_str(value),
196        ) {
197            self.headers.push(TestHeader::safe(name, value));
198        }
199        self
200    }
201
202    /// Adds a security-sensitive request header, bypassing the in-process guard.
203    pub fn unsafe_header(mut self, name: &str, value: &str) -> Self {
204        if let (Ok(name), Ok(value)) = (
205            HeaderName::from_bytes(name.as_bytes()),
206            HeaderValue::from_str(value),
207        ) {
208            self.headers.push(TestHeader::unsafe_allowed(name, value));
209        }
210        self
211    }
212
213    /// Adds a query parameter.
214    pub fn query(mut self, name: &str, value: &str) -> Self {
215        self.query.push((name.to_owned(), value.to_owned()));
216        self
217    }
218
219    /// Encodes the parts and sends the request.
220    pub async fn send(self) -> Result<TestResponse> {
221        let mut body = Vec::new();
222        for part in &self.parts {
223            body.extend_from_slice(format!("--{MULTIPART_BOUNDARY}\r\n").as_bytes());
224            match (&part.filename, &part.content_type) {
225                (Some(filename), content_type) => {
226                    body.extend_from_slice(
227                        format!(
228                            "Content-Disposition: form-data; name=\"{}\"; filename=\"{}\"\r\n",
229                            part.name, filename
230                        )
231                        .as_bytes(),
232                    );
233                    if let Some(content_type) = content_type {
234                        body.extend_from_slice(
235                            format!("Content-Type: {content_type}\r\n").as_bytes(),
236                        );
237                    }
238                }
239                (None, _) => {
240                    body.extend_from_slice(
241                        format!("Content-Disposition: form-data; name=\"{}\"\r\n", part.name)
242                            .as_bytes(),
243                    );
244                }
245            }
246            body.extend_from_slice(b"\r\n");
247            body.extend_from_slice(&part.value);
248            body.extend_from_slice(b"\r\n");
249        }
250        body.extend_from_slice(format!("--{MULTIPART_BOUNDARY}--\r\n").as_bytes());
251
252        let content_type = HeaderValue::from_str(&format!(
253            "multipart/form-data; boundary={MULTIPART_BOUNDARY}"
254        ))
255        .map_err(|_| Error::internal("failed to build multipart content type"))?;
256        let pending = PendingBody {
257            content_type: Some(content_type),
258            bytes: Bytes::from(body),
259        };
260        self.shared
261            .send(self.method, self.path, self.query, self.headers, pending)
262            .await
263    }
264}
265
/// The `Content-Type` header name, re-exported for the client module.
///
/// This is a crate-visible alias of `http::header::CONTENT_TYPE`.
// NOTE(review): presumably the client uses this when attaching
// `PendingBody::content_type` to the outgoing request — confirm against the caller.
pub(crate) const CONTENT_TYPE_HEADER: HeaderName = CONTENT_TYPE;
268
#[cfg(test)]
mod tests {
    use super::super::client::{Shared, Transport};
    use super::*;
    use crate::app::App;
    use std::sync::Mutex;

    /// Form payload whose value contains a space, to exercise urlencoding.
    #[derive(serde::Serialize)]
    struct Query {
        word: &'static str,
    }

    /// A value whose serialization always fails.
    struct BrokenSerialize;

    impl Serialize for BrokenSerialize {
        fn serialize<S: serde::Serializer>(
            &self,
            _serializer: S,
        ) -> std::result::Result<S::Ok, S::Error> {
            Err(<S::Error as serde::ser::Error>::custom("boom"))
        }
    }

    /// Builds client state backed by an in-process app with no routes, no
    /// default headers, and an empty cookie jar.
    fn shared() -> Arc<Shared> {
        let app = App::new().build().unwrap();
        let state = Shared {
            transport: Transport::InProcess(Arc::new(app)),
            default_headers: http::HeaderMap::new(),
            unsafe_default_headers: http::HeaderMap::new(),
            cookies: Mutex::new(super::super::cookie::CookieJar::default()),
        };
        Arc::new(state)
    }

    #[test]
    fn json_and_form_reset_body_on_serialize_failure() {
        let builder = TestRequestBuilder::new(shared(), Method::POST, "/items");
        let builder = builder.json(&BrokenSerialize).form(&BrokenSerialize);

        // A failed serialization leaves the default (empty) body behind.
        let PendingBody {
            content_type,
            bytes,
        } = &builder.body;
        assert!(content_type.is_none());
        assert!(bytes.is_empty());
    }

    #[test]
    fn builder_collects_headers_query_and_bytes() {
        let builder = TestRequestBuilder::new(shared(), Method::PUT, "/items")
            .header("x-test", "1")
            // "\n" is not a valid header name, so this pair must be dropped.
            .header("\n", "ignored")
            .query("q", "space value")
            .bytes(Bytes::from_static(b"payload"));

        assert_eq!(builder.headers.len(), 1);

        let expected_query = vec![("q".to_owned(), "space value".to_owned())];
        assert_eq!(builder.query, expected_query);

        assert_eq!(builder.body.bytes, Bytes::from_static(b"payload"));
        assert!(builder.body.content_type.is_none());
        assert!(!builder.headers[0].unsafe_allowed);
    }

    #[test]
    fn unsafe_header_marks_the_entry() {
        let builder = TestRequestBuilder::new(shared(), Method::GET, "/items");
        let builder = builder.unsafe_header("host", "example.com");

        assert_eq!(builder.headers.len(), 1);
        assert!(builder.headers[0].unsafe_allowed);
    }

    #[tokio::test]
    async fn multipart_builder_encodes_text_and_file_parts() {
        let multipart = TestRequestBuilder::new(shared(), Method::POST, "/upload")
            .multipart()
            .text("title", "hello")
            .file_bytes("file", "note.txt", "text/plain", "payload")
            .query("kind", "docs")
            .header("x-test", "1");

        // The app has no routes, so a successfully encoded and dispatched
        // request comes back as 404.
        let response = multipart.send().await.unwrap();
        assert_eq!(response.status(), 404);
    }

    #[test]
    fn form_uses_urlencoding() {
        let payload = Query {
            word: "hello world",
        };
        let builder = TestRequestBuilder::new(shared(), Method::POST, "/search").form(&payload);

        let expected_type = HeaderValue::from_static("application/x-www-form-urlencoded");
        assert_eq!(builder.body.content_type, Some(expected_type));
        assert_eq!(builder.body.bytes, Bytes::from_static(b"word=hello+world"));
    }
}