1use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6
7use bytes::Bytes;
8use http::header::HOST;
9use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
10use http_body_util::{BodyExt, Full};
11use hyper_util::rt::TokioIo;
12use tokio::net::TcpStream;
13use tokio::sync::oneshot;
14use tokio::task::JoinHandle;
15
16use super::cookie::CookieJar;
17use super::recorder::LogRecorder;
18use super::request::{PendingBody, TestRequestBuilder};
19use super::response::TestResponse;
20use super::websocket::TestWebSocketBuilder;
21use super::TestOverrides;
22use crate::app::{App, AppInner, TestApp};
23use crate::body::{box_body, BoxError, ReqBody};
24use crate::error::{Error, Result};
25use crate::state::StateMap;
26
27pub(crate) type StreamingBody =
29 Pin<Box<dyn http_body::Body<Data = Bytes, Error = BoxError> + Send>>;
30
31type ResourceRegister = Box<dyn FnOnce(&mut StateMap)>;
33
34#[derive(Clone)]
36pub(crate) struct TestHeader {
37 pub(crate) name: HeaderName,
38 pub(crate) value: HeaderValue,
39 pub(crate) unsafe_allowed: bool,
40}
41
42impl TestHeader {
43 pub(crate) fn safe(name: HeaderName, value: HeaderValue) -> Self {
44 Self {
45 name,
46 value,
47 unsafe_allowed: false,
48 }
49 }
50
51 pub(crate) fn unsafe_allowed(name: HeaderName, value: HeaderValue) -> Self {
52 Self {
53 name,
54 value,
55 unsafe_allowed: true,
56 }
57 }
58}
59
/// Lower-case names of headers that describe the request's host or the proxy
/// chain it passed through. In-process clients refuse to send these unless
/// the caller explicitly opts in.
const SENSITIVE_TEST_HEADERS: [&str; 5] = [
    "host", "forwarded", "x-forwarded-for", "x-forwarded-host", "x-forwarded-proto",
];
67
68pub(crate) enum Transport {
73 InProcess(Arc<AppInner>),
75 RealPort(SocketAddr),
77}
78
79impl Transport {
80 pub(crate) fn address(&self) -> Option<SocketAddr> {
82 match self {
83 Transport::InProcess(_) => None,
84 Transport::RealPort(addr) => Some(*addr),
85 }
86 }
87
88 pub(crate) async fn execute(
90 &self,
91 request: http::Request<ReqBody>,
92 ) -> Result<(StatusCode, HeaderMap, Bytes)> {
93 match self {
94 Transport::InProcess(app) => {
95 let response = app.clone().handle(request).await;
96 let (parts, body) = response.into_parts();
97 let bytes = collect_body(body).await?;
98 Ok((parts.status, parts.headers, bytes))
99 }
100 Transport::RealPort(addr) => {
101 let response = send_over_socket(*addr, request).await?;
102 let (parts, body) = response.into_parts();
103 let bytes = collect_body(body).await?;
104 Ok((parts.status, parts.headers, bytes))
105 }
106 }
107 }
108
109 pub(crate) async fn execute_streaming(
111 &self,
112 request: http::Request<ReqBody>,
113 ) -> Result<(StatusCode, HeaderMap, StreamingBody)> {
114 match self {
115 Transport::InProcess(app) => {
116 let response = app.clone().handle(request).await;
117 let (parts, body) = response.into_parts();
118 Ok((parts.status, parts.headers, Box::pin(body)))
119 }
120 Transport::RealPort(addr) => {
121 let response = send_over_socket(*addr, request).await?;
122 let (parts, body) = response.into_parts();
123 let boxed: StreamingBody =
124 Box::pin(body.map_err(|error| Box::new(error) as BoxError));
125 Ok((parts.status, parts.headers, boxed))
126 }
127 }
128 }
129}
130
131async fn collect_body<B>(body: B) -> Result<Bytes>
133where
134 B: http_body::Body<Data = Bytes>,
135 B::Error: std::fmt::Display,
136{
137 let collected = body
138 .collect()
139 .await
140 .map_err(|error| Error::internal(format!("failed to read response body: {error}")))?;
141 Ok(collected.to_bytes())
142}
143
144async fn send_over_socket(
146 addr: SocketAddr,
147 mut request: http::Request<ReqBody>,
148) -> Result<http::Response<hyper::body::Incoming>> {
149 if !request.headers().contains_key(HOST) {
151 if let Ok(value) = HeaderValue::from_str(&addr.to_string()) {
152 request.headers_mut().insert(HOST, value);
153 }
154 }
155
156 let stream = TcpStream::connect(addr)
157 .await
158 .map_err(|error| Error::internal(format!("failed to connect to {addr}: {error}")))?;
159 let io = TokioIo::new(stream);
160 let (mut sender, connection) = hyper::client::conn::http1::handshake(io)
161 .await
162 .map_err(|error| Error::internal(format!("client handshake failed: {error}")))?;
163 tokio::spawn(async move {
164 let _ = connection.await;
165 });
166 sender
167 .send_request(request)
168 .await
169 .map_err(|error| Error::internal(format!("request failed: {error}")))
170}
171
172pub(crate) struct Shared {
174 pub(crate) transport: Transport,
175 pub(crate) default_headers: HeaderMap,
176 pub(crate) unsafe_default_headers: HeaderMap,
177 pub(crate) cookies: Mutex<CookieJar>,
178}
179
180impl Shared {
181 pub(crate) async fn send(
183 &self,
184 method: Method,
185 path: String,
186 query: Vec<(String, String)>,
187 headers: Vec<TestHeader>,
188 body: PendingBody,
189 ) -> Result<TestResponse> {
190 let request = self.build_request(method, &path, &query, headers, body)?;
191 let (status, headers, bytes) = self.transport.execute(request).await?;
192 self.cookies
193 .lock()
194 .expect("cookie jar mutex poisoned")
195 .store(&headers);
196 Ok(TestResponse {
197 status,
198 headers,
199 body: bytes,
200 })
201 }
202
203 pub(crate) async fn open_sse(
206 &self,
207 method: Method,
208 path: String,
209 query: Vec<(String, String)>,
210 headers: Vec<TestHeader>,
211 ) -> Result<super::sse::TestSseStream> {
212 let request = self.build_request(method, &path, &query, headers, PendingBody::default())?;
213 let (_status, headers, body) = self.transport.execute_streaming(request).await?;
214 self.cookies
215 .lock()
216 .expect("cookie jar mutex poisoned")
217 .store(&headers);
218 Ok(super::sse::TestSseStream::new(body))
219 }
220
221 pub(crate) fn build_request(
224 &self,
225 method: Method,
226 path: &str,
227 query: &[(String, String)],
228 headers: Vec<TestHeader>,
229 body: PendingBody,
230 ) -> Result<http::Request<ReqBody>> {
231 let uri = if query.is_empty() {
232 path.to_owned()
233 } else {
234 let encoded = serde_urlencoded::to_string(query)
235 .map_err(|_| Error::internal("failed to encode query parameters"))?;
236 format!("{path}?{encoded}")
237 };
238
239 let mut request = http::Request::new(box_body(Full::new(body.bytes)));
240 *request.method_mut() = method;
241 *request.uri_mut() = uri
242 .parse()
243 .map_err(|_| Error::bad_request(format!("invalid request URI: {uri}")))?;
244
245 self.reject_in_process_sensitive_headers(&headers)?;
246
247 let map = request.headers_mut();
248 for (name, value) in self.default_headers.iter() {
249 map.insert(name, value.clone());
250 }
251 for (name, value) in self.unsafe_default_headers.iter() {
252 map.insert(name, value.clone());
253 }
254 for header in headers {
255 map.insert(header.name, header.value);
256 }
257 self.cookies
258 .lock()
259 .expect("cookie jar mutex poisoned")
260 .apply(map);
261 if let Some(content_type) = body.content_type {
262 map.insert(super::request::CONTENT_TYPE_HEADER, content_type);
263 }
264
265 Ok(request)
266 }
267
268 pub(crate) fn reject_in_process_sensitive_headers(&self, headers: &[TestHeader]) -> Result<()> {
269 if !matches!(self.transport, Transport::InProcess(_)) {
270 return Ok(());
271 }
272
273 let mut blocked = Vec::new();
274 for header in headers {
275 if !header.unsafe_allowed && is_sensitive_test_header(&header.name) {
276 blocked.push(header.name.as_str().to_owned());
277 }
278 }
279 if blocked.is_empty() {
280 Ok(())
281 } else {
282 Err(sensitive_header_error(&blocked))
283 }
284 }
285}
286
287fn is_sensitive_test_header(name: &HeaderName) -> bool {
288 SENSITIVE_TEST_HEADERS
289 .iter()
290 .any(|candidate| *candidate == name.as_str())
291}
292
293fn sensitive_header_error(headers: &[String]) -> Error {
294 Error::bad_request(format!(
295 "in-process test clients reject security-sensitive header(s): {}; use unsafe_header/unsafe_default_header or TestClient::serve(...).bind_random_port()",
296 headers.join(", ")
297 ))
298 .with_code("TEST_UNSAFE_HEADER_REQUIRES_OPT_IN")
299}
300
301pub struct TestClient {
308 shared: Arc<Shared>,
309 teardown: Teardown,
310 _log_guard: Option<tracing::subscriber::DefaultGuard>,
313}
314
315enum Teardown {
317 InProcess(Box<TestApp>),
319 RealPort {
321 shutdown: Option<oneshot::Sender<()>>,
322 handle: JoinHandle<()>,
323 },
324}
325
326impl TestClient {
327 pub async fn new(app: TestApp) -> Result<Self> {
329 Ok(Self {
330 shared: Arc::new(Shared {
331 transport: Transport::InProcess(app.inner.clone()),
332 default_headers: HeaderMap::new(),
333 unsafe_default_headers: HeaderMap::new(),
334 cookies: Mutex::new(CookieJar::default()),
335 }),
336 teardown: Teardown::InProcess(Box::new(app)),
337 _log_guard: None,
338 })
339 }
340
341 pub fn builder(app: App) -> TestClientBuilder {
343 TestClientBuilder::new(app)
344 }
345
346 pub fn serve(app: App) -> ServeBuilder {
348 ServeBuilder { app }
349 }
350
351 pub fn local_addr(&self) -> Option<SocketAddr> {
353 self.shared.transport.address()
354 }
355
356 pub fn websocket(&self, path: &str) -> TestWebSocketBuilder {
358 TestWebSocketBuilder::new(self.shared.clone(), path)
359 }
360
361 pub fn get(&self, path: &str) -> TestRequestBuilder {
363 TestRequestBuilder::new(self.shared.clone(), Method::GET, path)
364 }
365
366 pub fn post(&self, path: &str) -> TestRequestBuilder {
368 TestRequestBuilder::new(self.shared.clone(), Method::POST, path)
369 }
370
371 pub fn put(&self, path: &str) -> TestRequestBuilder {
373 TestRequestBuilder::new(self.shared.clone(), Method::PUT, path)
374 }
375
376 pub fn patch(&self, path: &str) -> TestRequestBuilder {
378 TestRequestBuilder::new(self.shared.clone(), Method::PATCH, path)
379 }
380
381 pub fn delete(&self, path: &str) -> TestRequestBuilder {
383 TestRequestBuilder::new(self.shared.clone(), Method::DELETE, path)
384 }
385
386 pub async fn shutdown(self) -> Result<()> {
389 match self.teardown {
390 Teardown::InProcess(app) => app.shutdown().await,
391 Teardown::RealPort { shutdown, handle } => {
392 if let Some(sender) = shutdown {
393 let _ = sender.send(());
394 }
395 let _ = handle.await;
396 Ok(())
397 }
398 }
399 }
400}
401
402pub struct ServeBuilder {
408 app: App,
409}
410
411impl ServeBuilder {
412 pub async fn bind_random_port(self) -> Result<TestClient> {
414 let (addr_tx, addr_rx) = oneshot::channel::<Result<SocketAddr>>();
415 let (shutdown_tx, shutdown_rx) = oneshot::channel();
416 let sender = Arc::new(Mutex::new(Some(addr_tx)));
417 let ready_sender = sender.clone();
418
419 let app = self.app.on_ready(move |ctx| {
420 let sender = ready_sender.clone();
421 async move {
422 if let Some(tx) = sender.lock().expect("address sender mutex poisoned").take() {
423 let _ = tx.send(Ok(ctx.addr()));
424 }
425 Ok(())
426 }
427 });
428
429 let sender = sender.clone();
430 let handle = tokio::spawn(async move {
431 let result = app
432 .serve_with_shutdown("127.0.0.1:0", async move {
433 let _ = shutdown_rx.await;
434 })
435 .await;
436 if let (Err(error), Some(tx)) = (
437 result,
438 sender.lock().expect("address sender mutex poisoned").take(),
439 ) {
440 let _ = tx.send(Err(error));
441 }
442 });
443
444 let addr = addr_rx
445 .await
446 .map_err(|_| Error::internal("the test server failed to start"))??;
447
448 Ok(TestClient {
449 shared: Arc::new(Shared {
450 transport: Transport::RealPort(addr),
451 default_headers: HeaderMap::new(),
452 unsafe_default_headers: HeaderMap::new(),
453 cookies: Mutex::new(CookieJar::default()),
454 }),
455 teardown: Teardown::RealPort {
456 shutdown: Some(shutdown_tx),
457 handle,
458 },
459 _log_guard: None,
460 })
461 }
462}
463
464pub struct TestClientBuilder {
467 app: App,
468 resources: Vec<ResourceRegister>,
469 overrides: TestOverrides,
470 default_headers: HeaderMap,
471 unsafe_default_headers: HeaderMap,
472 blocked_sensitive_headers: Vec<String>,
473 cookies: CookieJar,
474 recorder: Option<LogRecorder>,
475}
476
477impl TestClientBuilder {
478 fn new(app: App) -> Self {
479 Self {
480 app,
481 resources: Vec::new(),
482 overrides: TestOverrides::default(),
483 default_headers: HeaderMap::new(),
484 unsafe_default_headers: HeaderMap::new(),
485 blocked_sensitive_headers: Vec::new(),
486 cookies: CookieJar::default(),
487 recorder: None,
488 }
489 }
490
491 pub fn logger(mut self, recorder: LogRecorder) -> Self {
496 self.recorder = Some(recorder);
497 self
498 }
499
500 pub fn resource<S: Send + Sync + 'static>(mut self, value: S) -> Self {
503 self.resources
504 .push(Box::new(move |state| state.insert(value)));
505 self
506 }
507
508 pub fn override_dependency<T: Clone + Send + Sync + 'static>(mut self, value: T) -> Self {
510 self.overrides.insert::<T, _>(move || value.clone());
511 self
512 }
513
514 pub fn override_dependency_with<T, F>(mut self, factory: F) -> Self
516 where
517 T: Send + 'static,
518 F: Fn() -> T + Send + Sync + 'static,
519 {
520 self.overrides.insert::<T, F>(factory);
521 self
522 }
523
524 pub fn default_header(mut self, name: &str, value: &str) -> Self {
526 if let (Ok(name), Ok(value)) = (
527 HeaderName::from_bytes(name.as_bytes()),
528 HeaderValue::from_str(value),
529 ) {
530 if is_sensitive_test_header(&name) {
531 self.blocked_sensitive_headers
532 .push(name.as_str().to_owned());
533 } else {
534 self.default_headers.insert(name, value);
535 }
536 }
537 self
538 }
539
540 pub fn unsafe_default_header(mut self, name: &str, value: &str) -> Self {
542 if let (Ok(name), Ok(value)) = (
543 HeaderName::from_bytes(name.as_bytes()),
544 HeaderValue::from_str(value),
545 ) {
546 self.unsafe_default_headers.insert(name, value);
547 }
548 self
549 }
550
551 pub fn cookie(mut self, name: &str, value: &str) -> Self {
553 self.cookies.set(name, value);
554 self
555 }
556
557 pub async fn build(self) -> Result<TestClient> {
559 let resources = self.resources;
560 let overrides = self.overrides;
561 let default_headers = self.default_headers;
562 let unsafe_default_headers = self.unsafe_default_headers;
563 let blocked_sensitive_headers = self.blocked_sensitive_headers;
564 let cookies = self.cookies;
565 let recorder = self.recorder;
566
567 if !blocked_sensitive_headers.is_empty() {
568 return Err(sensitive_header_error(&blocked_sensitive_headers));
569 }
570
571 let app = self
572 .app
573 .build_test_with(move |state| {
574 for register in resources {
575 register(state);
576 }
577 if !overrides.is_empty() {
578 state.insert(overrides);
579 }
580 })
581 .await?;
582
583 let log_guard = recorder.map(|recorder| {
585 use tracing_subscriber::layer::SubscriberExt;
586 let subscriber = tracing_subscriber::registry().with(recorder);
587 tracing::subscriber::set_default(subscriber)
588 });
589
590 Ok(TestClient {
591 shared: Arc::new(Shared {
592 transport: Transport::InProcess(app.inner.clone()),
593 default_headers,
594 unsafe_default_headers,
595 cookies: Mutex::new(cookies),
596 }),
597 teardown: Teardown::InProcess(Box::new(app)),
598 _log_guard: log_guard,
599 })
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::App;
    use crate::body::{BoxError, RespBody};
    use crate::response::Response as TorkResponse;
    use crate::router::{BoxFuture, HandlerFn, Route, Router};
    use bytes::Bytes;
    use futures_util::stream;
    use http::header::{CONTENT_TYPE, COOKIE};
    use http_body::Frame;
    use http_body_util::{BodyExt, StreamBody};
    use std::sync::Arc;

    /// Handler answering `200` with the JSON document `{"ok": true}`.
    fn json_handler() -> HandlerFn {
        Arc::new(
            |_ctx: crate::extract::RequestContext| -> BoxFuture<'static, crate::Result<TorkResponse>> {
                Box::pin(async {
                    let payload = serde_json::json!({ "ok": true });
                    Ok(crate::json_response(crate::StatusCode::OK, &payload))
                })
            },
        )
    }

    /// Handler answering `200` with a two-frame `text/event-stream` body.
    fn stream_handler() -> HandlerFn {
        Arc::new(
            |_ctx: crate::extract::RequestContext| -> BoxFuture<'static, crate::Result<TorkResponse>> {
                Box::pin(async {
                    let frames = stream::iter(vec![
                        Ok::<_, BoxError>(Frame::data(Bytes::from_static(b"one"))),
                        Ok(Frame::data(Bytes::from_static(b"two"))),
                    ]);
                    let mut response =
                        TorkResponse::new(RespBody::stream(StreamBody::new(frames)));
                    let headers = response.headers_mut();
                    headers.insert(
                        CONTENT_TYPE,
                        http::HeaderValue::from_static("text/event-stream"),
                    );
                    *response.status_mut() = crate::StatusCode::OK;
                    Ok(response)
                })
            },
        )
    }

    /// In-process `Shared` fixture with one default header and one cookie.
    fn shared() -> Shared {
        let mut cookies = CookieJar::default();
        cookies.set("sid", "abc");
        let mut default_headers = HeaderMap::new();
        default_headers.insert("x-default", HeaderValue::from_static("on"));
        Shared {
            transport: Transport::InProcess(Arc::new(App::new().build().unwrap())),
            default_headers,
            unsafe_default_headers: HeaderMap::new(),
            cookies: Mutex::new(cookies),
        }
    }

    /// Builds a query-less, body-less request against a fresh `shared()` fixture.
    fn bare_request(
        method: Method,
        path: &str,
        headers: Vec<TestHeader>,
    ) -> crate::error::Result<http::Request<ReqBody>> {
        shared().build_request(method, path, &[], headers, PendingBody::default())
    }

    #[test]
    fn build_request_merges_defaults_headers_cookies_and_content_type() {
        let custom = TestHeader::safe(
            HeaderName::from_static("x-custom"),
            HeaderValue::from_static("yes"),
        );
        let body = PendingBody {
            content_type: Some(HeaderValue::from_static("application/json")),
            bytes: Bytes::from_static(b"{}"),
        };
        let query = [("q".to_owned(), "hello world".to_owned())];

        let request = shared()
            .build_request(Method::POST, "/items", &query, vec![custom], body)
            .unwrap();
        let headers = request.headers();

        assert_eq!(request.uri(), "/items?q=hello+world");
        assert_eq!(headers["x-default"], "on");
        assert_eq!(headers["x-custom"], "yes");
        assert_eq!(headers[COOKIE], "sid=abc");
        assert_eq!(headers[CONTENT_TYPE], "application/json");
    }

    #[test]
    fn build_request_rejects_invalid_uri() {
        let error = bare_request(Method::GET, "http://[", Vec::new()).unwrap_err();

        assert_eq!(error.kind(), crate::error::ErrorKind::BadRequest);
        assert!(error.message().starts_with("invalid request URI:"));
    }

    #[test]
    fn build_request_rejects_sensitive_headers_in_process_without_opt_in() {
        let host = TestHeader::safe(
            HeaderName::from_static("host"),
            HeaderValue::from_static("example.com"),
        );
        let error = bare_request(Method::GET, "/items", vec![host]).unwrap_err();

        assert_eq!(error.code(), "TEST_UNSAFE_HEADER_REQUIRES_OPT_IN");
        assert!(error.message().contains("host"));
    }

    #[test]
    fn build_request_allows_sensitive_headers_with_opt_in() {
        let host = TestHeader::unsafe_allowed(
            HeaderName::from_static("host"),
            HeaderValue::from_static("example.com"),
        );
        let request = bare_request(Method::GET, "/items", vec![host]).unwrap();

        assert_eq!(request.headers()["host"], "example.com");
    }

    #[tokio::test]
    async fn real_port_transport_exercises_execute_and_execute_streaming() {
        let router = Router::new()
            .route(Route::new(Method::GET, "/json", json_handler()))
            .route(Route::new(Method::GET, "/stream", stream_handler()));
        let client = TestClient::serve(App::new().include_router(router))
            .bind_random_port()
            .await
            .unwrap();
        let shared = &client.shared;
        let get = |path: &str| {
            shared
                .build_request(Method::GET, path, &[], Vec::new(), PendingBody::default())
                .unwrap()
        };

        assert!(client.local_addr().is_some());
        assert!(shared.transport.address().is_some());

        let (status, headers, bytes) = shared.transport.execute(get("/json")).await.unwrap();
        assert_eq!(status, StatusCode::OK);
        assert_eq!(headers[CONTENT_TYPE], "application/json");
        assert!(bytes.contains(&b'o'));

        let (status, headers, mut body) = shared
            .transport
            .execute_streaming(get("/stream"))
            .await
            .unwrap();
        assert_eq!(status, StatusCode::OK);
        assert_eq!(headers[CONTENT_TYPE], "text/event-stream");

        let mut saw_data = false;
        while let Some(frame) = body.frame().await {
            saw_data |= frame.unwrap().into_data().is_ok();
        }
        assert!(saw_data);

        client.shutdown().await.unwrap();
    }
}