1use std::sync::Arc;
4
5use bytes::Bytes;
6use http::header::CONTENT_TYPE;
7use http::{HeaderName, HeaderValue, Method};
8use serde::Serialize;
9
10use super::client::{Shared, TestHeader};
11use super::response::TestResponse;
12use crate::error::{Error, Result};
13
/// Boundary string that `TestMultipartBuilder::send` places between parts and
/// advertises in its `multipart/form-data; boundary=…` content type.
///
/// NOTE(review): part payloads are not scanned for this sequence; a payload that
/// happens to contain `--` followed by it would corrupt the encoded body.
const MULTIPART_BOUNDARY: &str = "----TorkTestBoundary7MA4YWxkTrZu0gW";
16
17#[derive(Default)]
19pub(crate) struct PendingBody {
20 pub(crate) content_type: Option<HeaderValue>,
21 pub(crate) bytes: Bytes,
22}
23
24pub struct TestRequestBuilder {
30 shared: Arc<Shared>,
31 method: Method,
32 path: String,
33 query: Vec<(String, String)>,
34 headers: Vec<TestHeader>,
35 body: PendingBody,
36}
37
38impl TestRequestBuilder {
39 pub(crate) fn new(shared: Arc<Shared>, method: Method, path: impl Into<String>) -> Self {
40 Self {
41 shared,
42 method,
43 path: path.into(),
44 query: Vec::new(),
45 headers: Vec::new(),
46 body: PendingBody::default(),
47 }
48 }
49
50 pub fn header(mut self, name: &str, value: &str) -> Self {
52 if let (Ok(name), Ok(value)) = (
53 HeaderName::from_bytes(name.as_bytes()),
54 HeaderValue::from_str(value),
55 ) {
56 self.headers.push(TestHeader::safe(name, value));
57 }
58 self
59 }
60
61 pub fn unsafe_header(mut self, name: &str, value: &str) -> Self {
63 if let (Ok(name), Ok(value)) = (
64 HeaderName::from_bytes(name.as_bytes()),
65 HeaderValue::from_str(value),
66 ) {
67 self.headers.push(TestHeader::unsafe_allowed(name, value));
68 }
69 self
70 }
71
72 pub fn query(mut self, name: &str, value: &str) -> Self {
74 self.query.push((name.to_owned(), value.to_owned()));
75 self
76 }
77
78 pub fn json<T: Serialize>(mut self, value: &T) -> Self {
80 match serde_json::to_vec(value) {
81 Ok(bytes) => {
82 self.body = PendingBody {
83 content_type: Some(HeaderValue::from_static("application/json")),
84 bytes: Bytes::from(bytes),
85 };
86 }
87 Err(_) => self.body = PendingBody::default(),
88 }
89 self
90 }
91
92 pub fn form<T: Serialize>(mut self, value: &T) -> Self {
94 match serde_urlencoded::to_string(value) {
95 Ok(text) => {
96 self.body = PendingBody {
97 content_type: Some(HeaderValue::from_static(
98 "application/x-www-form-urlencoded",
99 )),
100 bytes: Bytes::from(text.into_bytes()),
101 };
102 }
103 Err(_) => self.body = PendingBody::default(),
104 }
105 self
106 }
107
108 pub fn bytes(mut self, bytes: impl Into<Bytes>) -> Self {
110 self.body = PendingBody {
111 content_type: None,
112 bytes: bytes.into(),
113 };
114 self
115 }
116
117 pub fn multipart(self) -> TestMultipartBuilder {
119 TestMultipartBuilder {
120 shared: self.shared,
121 method: self.method,
122 path: self.path,
123 query: self.query,
124 headers: self.headers,
125 parts: Vec::new(),
126 }
127 }
128
129 pub async fn send(self) -> Result<TestResponse> {
131 self.shared
132 .send(self.method, self.path, self.query, self.headers, self.body)
133 .await
134 }
135
136 pub async fn sse(self) -> Result<super::sse::TestSseStream> {
138 self.shared
139 .open_sse(self.method, self.path, self.query, self.headers)
140 .await
141 }
142}
143
144struct MultipartPart {
146 name: String,
147 filename: Option<String>,
148 content_type: Option<String>,
149 value: Bytes,
150}
151
152pub struct TestMultipartBuilder {
154 shared: Arc<Shared>,
155 method: Method,
156 path: String,
157 query: Vec<(String, String)>,
158 headers: Vec<TestHeader>,
159 parts: Vec<MultipartPart>,
160}
161
162impl TestMultipartBuilder {
163 pub fn text(mut self, name: &str, value: &str) -> Self {
165 self.parts.push(MultipartPart {
166 name: name.to_owned(),
167 filename: None,
168 content_type: None,
169 value: Bytes::from(value.to_owned().into_bytes()),
170 });
171 self
172 }
173
174 pub fn file_bytes(
176 mut self,
177 name: &str,
178 filename: &str,
179 content_type: &str,
180 bytes: impl Into<Bytes>,
181 ) -> Self {
182 self.parts.push(MultipartPart {
183 name: name.to_owned(),
184 filename: Some(filename.to_owned()),
185 content_type: Some(content_type.to_owned()),
186 value: bytes.into(),
187 });
188 self
189 }
190
191 pub fn header(mut self, name: &str, value: &str) -> Self {
193 if let (Ok(name), Ok(value)) = (
194 HeaderName::from_bytes(name.as_bytes()),
195 HeaderValue::from_str(value),
196 ) {
197 self.headers.push(TestHeader::safe(name, value));
198 }
199 self
200 }
201
202 pub fn unsafe_header(mut self, name: &str, value: &str) -> Self {
204 if let (Ok(name), Ok(value)) = (
205 HeaderName::from_bytes(name.as_bytes()),
206 HeaderValue::from_str(value),
207 ) {
208 self.headers.push(TestHeader::unsafe_allowed(name, value));
209 }
210 self
211 }
212
213 pub fn query(mut self, name: &str, value: &str) -> Self {
215 self.query.push((name.to_owned(), value.to_owned()));
216 self
217 }
218
219 pub async fn send(self) -> Result<TestResponse> {
221 let mut body = Vec::new();
222 for part in &self.parts {
223 body.extend_from_slice(format!("--{MULTIPART_BOUNDARY}\r\n").as_bytes());
224 match (&part.filename, &part.content_type) {
225 (Some(filename), content_type) => {
226 body.extend_from_slice(
227 format!(
228 "Content-Disposition: form-data; name=\"{}\"; filename=\"{}\"\r\n",
229 part.name, filename
230 )
231 .as_bytes(),
232 );
233 if let Some(content_type) = content_type {
234 body.extend_from_slice(
235 format!("Content-Type: {content_type}\r\n").as_bytes(),
236 );
237 }
238 }
239 (None, _) => {
240 body.extend_from_slice(
241 format!("Content-Disposition: form-data; name=\"{}\"\r\n", part.name)
242 .as_bytes(),
243 );
244 }
245 }
246 body.extend_from_slice(b"\r\n");
247 body.extend_from_slice(&part.value);
248 body.extend_from_slice(b"\r\n");
249 }
250 body.extend_from_slice(format!("--{MULTIPART_BOUNDARY}--\r\n").as_bytes());
251
252 let content_type = HeaderValue::from_str(&format!(
253 "multipart/form-data; boundary={MULTIPART_BOUNDARY}"
254 ))
255 .map_err(|_| Error::internal("failed to build multipart content type"))?;
256 let pending = PendingBody {
257 content_type: Some(content_type),
258 bytes: Bytes::from(body),
259 };
260 self.shared
261 .send(self.method, self.path, self.query, self.headers, pending)
262 .await
263 }
264}
265
/// Crate-visible alias for [`http::header::CONTENT_TYPE`].
pub(crate) const CONTENT_TYPE_HEADER: HeaderName = CONTENT_TYPE;
268
#[cfg(test)]
mod tests {
    use super::super::client::{Shared, Transport};
    use super::*;
    use crate::app::App;
    use std::sync::Mutex;

    /// Single-field form fixture for the urlencoding test.
    #[derive(serde::Serialize)]
    struct SearchForm {
        word: &'static str,
    }

    /// Value whose `Serialize` impl always fails.
    ///
    /// It drives the error branches of `json` and `form`.
    struct AlwaysFails;

    impl Serialize for AlwaysFails {
        fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            Err(<S::Error as serde::ser::Error>::custom("boom"))
        }
    }

    /// Builds client state around a freshly built `App::new()`.
    ///
    /// No routes are added and the header maps and cookie jar start empty.
    fn test_shared() -> Arc<Shared> {
        let app = Arc::new(App::new().build().unwrap());
        let state = Shared {
            transport: Transport::InProcess(app),
            default_headers: http::HeaderMap::new(),
            unsafe_default_headers: http::HeaderMap::new(),
            cookies: Mutex::new(super::super::cookie::CookieJar::default()),
        };
        Arc::new(state)
    }

    #[test]
    fn json_and_form_reset_body_on_serialize_failure() {
        let builder = TestRequestBuilder::new(test_shared(), Method::POST, "/items")
            .json(&AlwaysFails)
            .form(&AlwaysFails);

        let PendingBody {
            content_type,
            bytes,
        } = &builder.body;
        assert!(content_type.is_none());
        assert!(bytes.is_empty());
    }

    #[test]
    fn builder_collects_headers_query_and_bytes() {
        let builder = TestRequestBuilder::new(test_shared(), Method::PUT, "/items")
            .header("x-test", "1")
            .header("\n", "ignored")
            .query("q", "space value")
            .bytes(Bytes::from_static(b"payload"));

        let expected_query = vec![(String::from("q"), String::from("space value"))];
        assert_eq!(builder.query, expected_query);

        // The newline header name is rejected, so only `x-test` survives.
        assert_eq!(builder.headers.len(), 1);
        assert!(!builder.headers[0].unsafe_allowed);

        assert_eq!(builder.body.bytes, Bytes::from_static(b"payload"));
        assert!(builder.body.content_type.is_none());
    }

    #[test]
    fn unsafe_header_marks_the_entry() {
        let builder = TestRequestBuilder::new(test_shared(), Method::GET, "/items")
            .unsafe_header("host", "example.com");

        assert_eq!(builder.headers.len(), 1);
        assert!(builder.headers.iter().all(|entry| entry.unsafe_allowed));
    }

    #[tokio::test]
    async fn multipart_builder_encodes_text_and_file_parts() {
        let upload = TestRequestBuilder::new(test_shared(), Method::POST, "/upload")
            .multipart()
            .text("title", "hello")
            .file_bytes("file", "note.txt", "text/plain", "payload")
            .query("kind", "docs")
            .header("x-test", "1");

        let response = upload.send().await.unwrap();
        assert_eq!(response.status(), 404);
    }

    #[test]
    fn form_uses_urlencoding() {
        let form = SearchForm {
            word: "hello world",
        };
        let builder =
            TestRequestBuilder::new(test_shared(), Method::POST, "/search").form(&form);

        let expected_type = HeaderValue::from_static("application/x-www-form-urlencoded");
        assert_eq!(builder.body.content_type, Some(expected_type));
        assert_eq!(builder.body.bytes, Bytes::from_static(b"word=hello+world"));
    }
}